From a42cd5ae35b870e6eef73bb81e086407afa42285 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Fri, 27 Oct 2023 09:14:51 -0500 Subject: [PATCH] Lazy: Simplify method signature; switch to AbstractModel --- miplearn/components/lazy/mem.py | 15 +++++---------- miplearn/problems/tsp.py | 10 ++++------ miplearn/solvers/abstract.py | 6 +++++- miplearn/solvers/gurobi.py | 13 ++++++++++++- tests/fixtures/tsp-n20-00000.h5 | Bin 37243 -> 65915 bytes 5 files changed, 26 insertions(+), 18 deletions(-) diff --git a/miplearn/components/lazy/mem.py b/miplearn/components/lazy/mem.py index eb3fc63..a08d2d5 100644 --- a/miplearn/components/lazy/mem.py +++ b/miplearn/components/lazy/mem.py @@ -9,16 +9,10 @@ from sklearn.preprocessing import MultiLabelBinarizer from miplearn.extractors.abstract import FeaturesExtractor from miplearn.h5 import H5File -from miplearn.solvers.gurobi import GurobiModel +from miplearn.solvers.abstract import AbstractModel logger = logging.getLogger(__name__) -# TODO: Replace GurobiModel by AbstractModel -# TODO: fix_violations: remove model.inner -# TODO: fix_violations: remove `where` -# TODO: Write documentation -# TODO: Implement ExpertLazyConstrComponent - class MemorizingLazyConstrComponent: def __init__(self, clf: Any, extractor: FeaturesExtractor) -> None: @@ -79,13 +73,14 @@ class MemorizingLazyConstrComponent: def before_mip( self, test_h5: str, - model: GurobiModel, + model: AbstractModel, stats: Dict[str, Any], ) -> None: - assert self.constrs_ is not None if model.lazy_enforce is None: return + assert self.constrs_ is not None + # Read features with H5File(test_h5, "r") as h5: x_sample = self.extractor.get_instance_features(h5) @@ -101,5 +96,5 @@ class MemorizingLazyConstrComponent: # Enforce constraints violations = [self.constrs_[i] for (i, yi) in enumerate(y) if yi > 0.5] logger.info(f"Enforcing {len(violations)} constraints ahead-of-time...") - model.lazy_enforce(model, violations, "aot") + model.lazy_enforce(model, violations) stats["Lazy Constraints: AOT"] = len(violations) diff --git a/miplearn/problems/tsp.py b/miplearn/problems/tsp.py index c11b057..dd910f4 100644 --- a/miplearn/problems/tsp.py +++ b/miplearn/problems/tsp.py @@ -159,13 +159,11 @@ def build_tsp_model(data: Union[str, TravelingSalesmanData]) -> GurobiModel: violations.append(cut_edges) return violations - def lazy_enforce(model: GurobiModel, violations: List[Any], where: str) -> None: + def lazy_enforce(model: GurobiModel, violations: List[Any]) -> None: for violation in violations: - constr = quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2 - if where == "cb": - model.inner.cbLazy(constr) - else: - model.inner.addConstr(constr) + model.add_constr( + quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2 + ) logger.info(f"tsp: added {len(violations)} subtour elimination constraints") model.update() diff --git a/miplearn/solvers/abstract.py b/miplearn/solvers/abstract.py index f37da41..cf985f7 100644 --- a/miplearn/solvers/abstract.py +++ b/miplearn/solvers/abstract.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. from abc import ABC, abstractmethod -from typing import Optional, Dict +from typing import Optional, Dict, Callable import numpy as np @@ -16,6 +16,10 @@ class AbstractModel(ABC): _supports_node_count = False _supports_solution_pool = False + def __init__(self) -> None: + self.lazy_enforce: Optional[Callable] = None + self.lazy_separate: Optional[Callable] = None + @abstractmethod def add_constrs( self, diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 90e1a20..63917d9 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -28,6 +28,7 @@ class GurobiModel(AbstractModel): self.lazy_enforce = lazy_enforce self.inner = inner self.lazy_constrs_: Optional[List[Any]] = None + self.where = "default" def add_constrs( self, @@ -53,6 +54,14 @@ class GurobiModel(AbstractModel): stats["Added constraints"] = 0 stats["Added constraints"] += nconstrs + def add_constr(self, constr: Any) -> None: + if self.where == "lazy": + self.inner.cbLazy(constr) + elif self.where == "cut": + self.inner.cbCut(constr) + else: + self.inner.addConstr(constr) + def extract_after_load(self, h5: H5File) -> None: """ Given a model that has just been loaded, extracts static problem @@ -132,9 +141,11 @@ class GurobiModel(AbstractModel): assert self.lazy_constrs_ is not None assert self.lazy_enforce is not None if where == GRB.Callback.MIPSOL: + self.where = "lazy" violations = self.lazy_separate(self) self.lazy_constrs_.extend(violations) - self.lazy_enforce(self, violations, "cb") + self.lazy_enforce(self, violations) + self.where = "default" if self.lazy_enforce is not None: self.inner.Params.lazyConstraints = 1 diff --git a/tests/fixtures/tsp-n20-00000.h5 b/tests/fixtures/tsp-n20-00000.h5 index d6bf57ab7a095d3ac12551f8effc97d8c2848d21..159a981087d348fc8259e08f4afc26ce47b1bf96 100644 GIT binary patch delta 1350 zcmeypi0O9|%LG|vMn(n@FaVPb3}OkHx&9t5K}-s9N3By!4_K$h9i8Y+>sTt*KdzW8&SJucumEHk z(2}sp>si*ZESnRtcXA@D7dCl*wskDA=5b-053+G{Fjarte2&k85nZ8ypzLHrf#(vN z*Z;YU%_W-+1X&m*3Vm0#V3VB;(w-^wT*CkU^(9yp{1FS7yhC^k%Ta#m;>`sjbxddu z*!)KfXuiQ0foL4sWha-h=uNg4=bc<2u|`7Z$@Sk@6$(@tOiq!!!op_2rm@*ZicJ9B zo&YUYM)4iDim~d3#F_5oZQ8peqPmuI;7}Ac`JT=?2{zFy$8jq7rMpg|a@p~8Yzj6X z&;$DQx0%xmY_i}uH8NNyv22M_7fuC!hU-|~DQZ`2GP=pirR8l8N~IHyZQcCXJRT!x z1FV47O`p1l6Q^m1tk+3I@JY7eRB+E`oy4DuM>b(q0Ev9d$-ivZNlf?2#2PW157+@6 zacS;ddmKiDP4;u-nQZ5lF!E*NxbdDaeE^ZdW80RbKpv&4kSh+<^{=3ZIEp z`{q3!K-puj|82ph7wkL^udCwwUmnIOui|}0Vv5#vtf_mmfDh1kwYgiLVAH-?#}_DD oIm2ltR#^`}7DkrJpG#hD4)E7@nrtJbvDu;V2LED-MgeRX0OaouS^xk5 delta 1316 zcmey}#PWL)(*#-Ni40(10HGL6Z=X8m@8J@}q)>OjI<*vt>kdqGsF&EboiPb4u2>5c zL*U5=^kgRsFtJbk{Xn8{VgNrLHP0n}rN#8&P!Kjbm3bY@tShT}CX2I}@F6S!Sq8Ku zZ1Q@Rbu7~Rqg*B@vU*{Y=Vx2T;vs#jZ1X`jZVo1i#?9yWEEv%h8VJfxHWYX+(OV$= z5SvRj8wj#6O03T1@5ClM8KgZ^=()t?xA#6`Rq#hFVDb*(Ei5lW#1C&S5UFEAbHL_5 zVnFkqkA1#~L%ZzcG8Vnb_Ts#g3nbQv?|t~^J{E;$kEM@&Eqjb zHoyvGog~*{oTeSJUMHbnqr3&Df_paWB%EeXu7xN7WdKk*gwW8qx19XTcAdnD-f|~| zB9H z7q|l*z-iU>2CM8I51?%H>18Xi$%387;dNC!;q_#k@+#g}B+@b^FJn`%S-=Nq{I``! gqF7~he1WnT-y3enD(m6L!pLI1-tW!k0Do;K0PMdFNB{r;