diff --git a/miplearn/components/lazy/mem.py b/miplearn/components/lazy/mem.py index eb3fc63..a08d2d5 100644 --- a/miplearn/components/lazy/mem.py +++ b/miplearn/components/lazy/mem.py @@ -9,16 +9,10 @@ from sklearn.preprocessing import MultiLabelBinarizer from miplearn.extractors.abstract import FeaturesExtractor from miplearn.h5 import H5File -from miplearn.solvers.gurobi import GurobiModel +from miplearn.solvers.abstract import AbstractModel logger = logging.getLogger(__name__) -# TODO: Replace GurobiModel by AbstractModel -# TODO: fix_violations: remove model.inner -# TODO: fix_violations: remove `where` -# TODO: Write documentation -# TODO: Implement ExpertLazyConstrComponent - class MemorizingLazyConstrComponent: def __init__(self, clf: Any, extractor: FeaturesExtractor) -> None: @@ -79,13 +73,14 @@ class MemorizingLazyConstrComponent: def before_mip( self, test_h5: str, - model: GurobiModel, + model: AbstractModel, stats: Dict[str, Any], ) -> None: - assert self.constrs_ is not None if model.lazy_enforce is None: return + assert self.constrs_ is not None + # Read features with H5File(test_h5, "r") as h5: x_sample = self.extractor.get_instance_features(h5) @@ -101,5 +96,5 @@ class MemorizingLazyConstrComponent: # Enforce constraints violations = [self.constrs_[i] for (i, yi) in enumerate(y) if yi > 0.5] logger.info(f"Enforcing {len(violations)} constraints ahead-of-time...") - model.lazy_enforce(model, violations, "aot") + model.lazy_enforce(model, violations) stats["Lazy Constraints: AOT"] = len(violations) diff --git a/miplearn/problems/tsp.py b/miplearn/problems/tsp.py index c11b057..dd910f4 100644 --- a/miplearn/problems/tsp.py +++ b/miplearn/problems/tsp.py @@ -159,13 +159,11 @@ def build_tsp_model(data: Union[str, TravelingSalesmanData]) -> GurobiModel: violations.append(cut_edges) return violations - def lazy_enforce(model: GurobiModel, violations: List[Any], where: str) -> None: + def lazy_enforce(model: GurobiModel, violations: List[Any]) -> None: for violation in violations: - constr = quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2 - if where == "cb": - model.inner.cbLazy(constr) - else: - model.inner.addConstr(constr) + model.add_constr( + quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2 + ) logger.info(f"tsp: added {len(violations)} subtour elimination constraints") model.update() diff --git a/miplearn/solvers/abstract.py b/miplearn/solvers/abstract.py index f37da41..cf985f7 100644 --- a/miplearn/solvers/abstract.py +++ b/miplearn/solvers/abstract.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. from abc import ABC, abstractmethod -from typing import Optional, Dict +from typing import Optional, Dict, Callable import numpy as np @@ -16,6 +16,10 @@ class AbstractModel(ABC): _supports_node_count = False _supports_solution_pool = False + def __init__(self) -> None: + self.lazy_enforce: Optional[Callable] = None + self.lazy_separate: Optional[Callable] = None + @abstractmethod def add_constrs( self, diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 90e1a20..63917d9 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -28,6 +28,7 @@ class GurobiModel(AbstractModel): self.lazy_enforce = lazy_enforce self.inner = inner self.lazy_constrs_: Optional[List[Any]] = None + self.where = "default" def add_constrs( self, @@ -53,6 +54,14 @@ class GurobiModel(AbstractModel): stats["Added constraints"] = 0 stats["Added constraints"] += nconstrs + def add_constr(self, constr: Any) -> None: + if self.where == "lazy": + self.inner.cbLazy(constr) + elif self.where == "cut": + self.inner.cbCut(constr) + else: + self.inner.addConstr(constr) + def extract_after_load(self, h5: H5File) -> None: """ Given a model that has just been loaded, extracts static problem @@ -132,9 +141,11 @@ class GurobiModel(AbstractModel): assert self.lazy_constrs_ is not None assert self.lazy_enforce is not None if where == GRB.Callback.MIPSOL: + self.where = "lazy" violations = self.lazy_separate(self) self.lazy_constrs_.extend(violations) - self.lazy_enforce(self, violations, "cb") + self.lazy_enforce(self, violations) + self.where = "default" if self.lazy_enforce is not None: self.inner.Params.lazyConstraints = 1 diff --git a/tests/fixtures/tsp-n20-00000.h5 b/tests/fixtures/tsp-n20-00000.h5 index d6bf57a..159a981 100644 Binary files a/tests/fixtures/tsp-n20-00000.h5 and b/tests/fixtures/tsp-n20-00000.h5 differ