mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Redesign InternalSolver constraint methods
This commit is contained in:
@@ -76,6 +76,7 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
) -> bool:
|
||||
assert solver.internal_solver is not None
|
||||
logger.debug("Finding violated lazy constraints...")
|
||||
cids = instance.find_violated_lazy_constraints(solver.internal_solver, model)
|
||||
if len(cids) == 0:
|
||||
|
||||
@@ -53,8 +53,7 @@ class UserCutsComponent(Component):
|
||||
cids = self.dynamic.sample_predict(instance, training_data)
|
||||
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
|
||||
for cid in cids:
|
||||
cobj = instance.build_user_cut(model, cid)
|
||||
solver.internal_solver.add_constraint(cobj)
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
stats["UserCuts: Added ahead-of-time"] = len(cids)
|
||||
|
||||
@overrides
|
||||
@@ -73,9 +72,7 @@ class UserCutsComponent(Component):
|
||||
if cid in self.enforced:
|
||||
continue
|
||||
assert isinstance(cid, Hashable)
|
||||
cobj = instance.build_user_cut(model, cid)
|
||||
assert cobj is not None
|
||||
solver.internal_solver.add_cut(cobj)
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
self.enforced.add(cid)
|
||||
self.n_added_in_callback += 1
|
||||
if len(cids) > 0:
|
||||
|
||||
@@ -12,7 +12,7 @@ from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.features import TrainingSample, Features
|
||||
from miplearn.features import TrainingSample, Features, Constraint
|
||||
from miplearn.types import LearningSolveStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,7 +44,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
self.threshold_prototype: Threshold = threshold
|
||||
self.classifiers: Dict[Hashable, Classifier] = {}
|
||||
self.thresholds: Dict[Hashable, Threshold] = {}
|
||||
self.pool: Dict[str, LazyConstraint] = {}
|
||||
self.pool: Dict[str, Constraint] = {}
|
||||
self.violation_tolerance: float = violation_tolerance
|
||||
self.enforced_cids: Set[Hashable] = set()
|
||||
self.n_restored: int = 0
|
||||
@@ -72,10 +72,8 @@ class StaticLazyConstraintsComponent(Component):
|
||||
self.pool = {}
|
||||
for (cid, cdict) in features.constraints.items():
|
||||
if cdict.lazy and cid not in self.enforced_cids:
|
||||
self.pool[cid] = LazyConstraint(
|
||||
cid=cid,
|
||||
obj=solver.internal_solver.extract_constraint(cid),
|
||||
)
|
||||
self.pool[cid] = cdict
|
||||
solver.internal_solver.remove_constraint(cid)
|
||||
logger.info(
|
||||
f"{len(self.enforced_cids)} lazy constraints kept; "
|
||||
f"{len(self.pool)} moved to the pool"
|
||||
@@ -124,18 +122,18 @@ class StaticLazyConstraintsComponent(Component):
|
||||
def _check_and_add(self, solver: "LearningSolver") -> bool:
|
||||
assert solver.internal_solver is not None
|
||||
logger.info("Finding violated lazy constraints...")
|
||||
enforced: List[LazyConstraint] = []
|
||||
enforced: Dict[str, Constraint] = {}
|
||||
for (cid, c) in self.pool.items():
|
||||
if not solver.internal_solver.is_constraint_satisfied(
|
||||
c.obj,
|
||||
c,
|
||||
tol=self.violation_tolerance,
|
||||
):
|
||||
enforced.append(c)
|
||||
enforced[cid] = c
|
||||
logger.info(f"{len(enforced)} violations found")
|
||||
for c in enforced:
|
||||
del self.pool[c.cid]
|
||||
solver.internal_solver.add_constraint(c.obj)
|
||||
self.enforced_cids.add(c.cid)
|
||||
for (cid, c) in enforced.items():
|
||||
del self.pool[cid]
|
||||
solver.internal_solver.add_constraint(c, name=cid)
|
||||
self.enforced_cids.add(cid)
|
||||
self.n_restored += 1
|
||||
logger.info(
|
||||
f"{len(enforced)} constraints restored; {len(self.pool)} in the pool"
|
||||
|
||||
Reference in New Issue
Block a user