Redesign InternalSolver constraint methods

This commit is contained in:
2021-04-10 15:46:53 -05:00
parent f70363db0d
commit 088d679f61
12 changed files with 160 additions and 221 deletions

View File

@@ -76,6 +76,7 @@ class DynamicLazyConstraintsComponent(Component):
instance: Instance,
model: Any,
) -> bool:
assert solver.internal_solver is not None
logger.debug("Finding violated lazy constraints...")
cids = instance.find_violated_lazy_constraints(solver.internal_solver, model)
if len(cids) == 0:

View File

@@ -53,8 +53,7 @@ class UserCutsComponent(Component):
cids = self.dynamic.sample_predict(instance, training_data)
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
for cid in cids:
cobj = instance.build_user_cut(model, cid)
solver.internal_solver.add_constraint(cobj)
instance.enforce_user_cut(solver.internal_solver, model, cid)
stats["UserCuts: Added ahead-of-time"] = len(cids)
@overrides
@@ -73,9 +72,7 @@ class UserCutsComponent(Component):
if cid in self.enforced:
continue
assert isinstance(cid, Hashable)
cobj = instance.build_user_cut(model, cid)
assert cobj is not None
solver.internal_solver.add_cut(cobj)
instance.enforce_user_cut(solver.internal_solver, model, cid)
self.enforced.add(cid)
self.n_added_in_callback += 1
if len(cids) > 0:

View File

@@ -12,7 +12,7 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
from miplearn.components.component import Component
from miplearn.features import TrainingSample, Features
from miplearn.features import TrainingSample, Features, Constraint
from miplearn.types import LearningSolveStats
logger = logging.getLogger(__name__)
@@ -44,7 +44,7 @@ class StaticLazyConstraintsComponent(Component):
self.threshold_prototype: Threshold = threshold
self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.pool: Dict[str, LazyConstraint] = {}
self.pool: Dict[str, Constraint] = {}
self.violation_tolerance: float = violation_tolerance
self.enforced_cids: Set[Hashable] = set()
self.n_restored: int = 0
@@ -72,10 +72,8 @@ class StaticLazyConstraintsComponent(Component):
self.pool = {}
for (cid, cdict) in features.constraints.items():
if cdict.lazy and cid not in self.enforced_cids:
self.pool[cid] = LazyConstraint(
cid=cid,
obj=solver.internal_solver.extract_constraint(cid),
)
self.pool[cid] = cdict
solver.internal_solver.remove_constraint(cid)
logger.info(
f"{len(self.enforced_cids)} lazy constraints kept; "
f"{len(self.pool)} moved to the pool"
@@ -124,18 +122,18 @@ class StaticLazyConstraintsComponent(Component):
def _check_and_add(self, solver: "LearningSolver") -> bool:
assert solver.internal_solver is not None
logger.info("Finding violated lazy constraints...")
enforced: List[LazyConstraint] = []
enforced: Dict[str, Constraint] = {}
for (cid, c) in self.pool.items():
if not solver.internal_solver.is_constraint_satisfied(
c.obj,
c,
tol=self.violation_tolerance,
):
enforced.append(c)
enforced[cid] = c
logger.info(f"{len(enforced)} violations found")
for c in enforced:
del self.pool[c.cid]
solver.internal_solver.add_constraint(c.obj)
self.enforced_cids.add(c.cid)
for (cid, c) in enforced.items():
del self.pool[cid]
solver.internal_solver.add_constraint(c, name=cid)
self.enforced_cids.add(cid)
self.n_restored += 1
logger.info(
f"{len(enforced)} constraints restored; {len(self.pool)} in the pool"