diff --git a/miplearn/problems/tsp.py b/miplearn/problems/tsp.py index 209d24c..eb494e1 100644 --- a/miplearn/problems/tsp.py +++ b/miplearn/problems/tsp.py @@ -149,21 +149,8 @@ def build_tsp_model_gurobipy( ) def lazy_separate(model: GurobiModel) -> List[Any]: - violations = [] - x = model.inner.cbGetSolution(model.inner._x) - selected_edges = [e for e in model.inner._edges if x[e] > 0.5] - graph = nx.Graph() - graph.add_edges_from(selected_edges) - for component in list(nx.connected_components(graph)): - if len(component) < model.inner._n_cities: - cut_edges = tuple( - (e[0], e[1]) - for e in model.inner._edges - if (e[0] in component and e[1] not in component) - or (e[0] not in component and e[1] in component) - ) - violations.append(cut_edges) - return violations + x_val = model.inner.cbGetSolution(model.inner._x) + return _tsp_separate(x_val, edges, data.n_cities) def lazy_enforce(model: GurobiModel, violations: List[Any]) -> None: for violation in violations: @@ -212,22 +199,9 @@ def build_tsp_model_pyomo( model.subtour_eqs = pe.ConstraintList() def lazy_separate(m: PyomoModel) -> List[Any]: - violations = [] m.solver.cbGetSolution([model.x[e] for e in edges]) x_val = {e: model.x[e].value for e in edges} - selected_edges = [e for e in edges if x_val[e] > 0.5] - graph = nx.Graph() - graph.add_edges_from(selected_edges) - for component in list(nx.connected_components(graph)): - if len(component) < data.n_cities: - cut_edges = tuple( - (e[0], e[1]) - for e in edges - if (e[0] in component and e[1] not in component) - or (e[0] not in component and e[1] in component) - ) - violations.append(cut_edges) - return violations + return _tsp_separate(x_val, edges, data.n_cities) def lazy_enforce(m: PyomoModel, violations: List[Any]) -> None: logger.warning(f"Adding {len(violations)} subtour elimination constraints...") @@ -251,3 +225,24 @@ def _tsp_read(data: Union[str, TravelingSalesmanData]) -> TravelingSalesmanData: data = read_pkl_gz(data) assert isinstance(data, TravelingSalesmanData) return data + + +def _tsp_separate( + x_val: dict[Tuple[int, int], float], + edges: List[Tuple[int, int]], + n_cities: int, +) -> List[Tuple[Tuple[int, int], ...]]: + violations = [] + selected_edges = [e for e in edges if x_val[e] > 0.5] + graph = nx.Graph() + graph.add_edges_from(selected_edges) + for component in list(nx.connected_components(graph)): + if len(component) < n_cities: + cut_edges = tuple( + (e[0], e[1]) + for e in edges + if (e[0] in component and e[1] not in component) + or (e[0] not in component and e[1] in component) + ) + violations.append(cut_edges) + return violations