Lazy: Simplify method signature; switch to AbstractModel

dev
Alinson S. Xavier 2 years ago
parent 7079a36203
commit a42cd5ae35
Signed by: isoron
GPG Key ID: 0DA8E4B9E1109DCA

@ -9,16 +9,10 @@ from sklearn.preprocessing import MultiLabelBinarizer
from miplearn.extractors.abstract import FeaturesExtractor
from miplearn.h5 import H5File
from miplearn.solvers.gurobi import GurobiModel
from miplearn.solvers.abstract import AbstractModel
logger = logging.getLogger(__name__)
# TODO: Replace GurobiModel by AbstractModel
# TODO: fix_violations: remove model.inner
# TODO: fix_violations: remove `where`
# TODO: Write documentation
# TODO: Implement ExpertLazyConstrComponent
class MemorizingLazyConstrComponent:
def __init__(self, clf: Any, extractor: FeaturesExtractor) -> None:
@ -79,13 +73,14 @@ class MemorizingLazyConstrComponent:
def before_mip(
self,
test_h5: str,
model: GurobiModel,
model: AbstractModel,
stats: Dict[str, Any],
) -> None:
assert self.constrs_ is not None
if model.lazy_enforce is None:
return
assert self.constrs_ is not None
# Read features
with H5File(test_h5, "r") as h5:
x_sample = self.extractor.get_instance_features(h5)
@ -101,5 +96,5 @@ class MemorizingLazyConstrComponent:
# Enforce constraints
violations = [self.constrs_[i] for (i, yi) in enumerate(y) if yi > 0.5]
logger.info(f"Enforcing {len(violations)} constraints ahead-of-time...")
model.lazy_enforce(model, violations, "aot")
model.lazy_enforce(model, violations)
stats["Lazy Constraints: AOT"] = len(violations)

@ -159,13 +159,11 @@ def build_tsp_model(data: Union[str, TravelingSalesmanData]) -> GurobiModel:
violations.append(cut_edges)
return violations
def lazy_enforce(model: GurobiModel, violations: List[Any], where: str) -> None:
def lazy_enforce(model: GurobiModel, violations: List[Any]) -> None:
for violation in violations:
constr = quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2
if where == "cb":
model.inner.cbLazy(constr)
else:
model.inner.addConstr(constr)
model.add_constr(
quicksum(model.inner._x[e[0], e[1]] for e in violation) >= 2
)
logger.info(f"tsp: added {len(violations)} subtour elimination constraints")
model.update()

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod
from typing import Optional, Dict
from typing import Optional, Dict, Callable
import numpy as np
@ -16,6 +16,10 @@ class AbstractModel(ABC):
_supports_node_count = False
_supports_solution_pool = False
def __init__(self) -> None:
self.lazy_enforce: Optional[Callable] = None
self.lazy_separate: Optional[Callable] = None
@abstractmethod
def add_constrs(
self,

@ -28,6 +28,7 @@ class GurobiModel(AbstractModel):
self.lazy_enforce = lazy_enforce
self.inner = inner
self.lazy_constrs_: Optional[List[Any]] = None
self.where = "default"
def add_constrs(
self,
@ -53,6 +54,14 @@ class GurobiModel(AbstractModel):
stats["Added constraints"] = 0
stats["Added constraints"] += nconstrs
def add_constr(self, constr: Any) -> None:
if self.where == "lazy":
self.inner.cbLazy(constr)
elif self.where == "cut":
self.inner.cbCut(constr)
else:
self.inner.addConstr(constr)
def extract_after_load(self, h5: H5File) -> None:
"""
Given a model that has just been loaded, extracts static problem
@ -132,9 +141,11 @@ class GurobiModel(AbstractModel):
assert self.lazy_constrs_ is not None
assert self.lazy_enforce is not None
if where == GRB.Callback.MIPSOL:
self.where = "lazy"
violations = self.lazy_separate(self)
self.lazy_constrs_.extend(violations)
self.lazy_enforce(self, violations, "cb")
self.lazy_enforce(self, violations)
self.where = "default"
if self.lazy_enforce is not None:
self.inner.Params.lazyConstraints = 1

Binary file not shown.
Loading…
Cancel
Save