mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Refactor StaticLazy; remove old constraint methods
This commit is contained in:
@@ -12,7 +12,7 @@ from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.features import Constraint, Sample
|
||||
from miplearn.features import Constraint, Sample, ConstraintFeatures
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import LearningSolveStats
|
||||
|
||||
@@ -45,7 +45,8 @@ class StaticLazyConstraintsComponent(Component):
|
||||
self.threshold_prototype: Threshold = threshold
|
||||
self.classifiers: Dict[Hashable, Classifier] = {}
|
||||
self.thresholds: Dict[Hashable, Threshold] = {}
|
||||
self.pool: Dict[str, Constraint] = {}
|
||||
self.pool_old: Dict[str, Constraint] = {}
|
||||
self.pool: ConstraintFeatures = ConstraintFeatures()
|
||||
self.violation_tolerance: float = violation_tolerance
|
||||
self.enforced_cids: Set[Hashable] = set()
|
||||
self.n_restored: int = 0
|
||||
@@ -78,24 +79,28 @@ class StaticLazyConstraintsComponent(Component):
|
||||
assert solver.internal_solver is not None
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.instance is not None
|
||||
assert sample.after_load.constraints_old is not None
|
||||
|
||||
logger.info("Predicting violated (static) lazy constraints...")
|
||||
if sample.after_load.instance.lazy_constraint_count == 0:
|
||||
logger.info("Instance does not have static lazy constraints. Skipping.")
|
||||
self.enforced_cids = set(self.sample_predict(sample))
|
||||
logger.info("Moving lazy constraints to the pool...")
|
||||
self.pool = {}
|
||||
for (cid, cdict) in sample.after_load.constraints_old.items():
|
||||
if cdict.lazy and cid not in self.enforced_cids:
|
||||
self.pool[cid] = cdict
|
||||
solver.internal_solver.remove_constraint(cid)
|
||||
logger.info(
|
||||
f"{len(self.enforced_cids)} lazy constraints kept; "
|
||||
f"{len(self.pool)} moved to the pool"
|
||||
constraints = sample.after_load.constraints
|
||||
assert constraints is not None
|
||||
assert constraints.lazy is not None
|
||||
assert constraints.names is not None
|
||||
selected = tuple(
|
||||
(constraints.lazy[i] and constraints.names[i] not in self.enforced_cids)
|
||||
for i in range(len(constraints.lazy))
|
||||
)
|
||||
stats["LazyStatic: Removed"] = len(self.pool)
|
||||
stats["LazyStatic: Kept"] = len(self.enforced_cids)
|
||||
n_removed = sum(selected)
|
||||
n_kept = sum(constraints.lazy) - n_removed
|
||||
self.pool = constraints[selected]
|
||||
assert self.pool.names is not None
|
||||
solver.internal_solver.remove_constraints(self.pool.names)
|
||||
logger.info(f"{n_kept} lazy constraints kept; {n_removed} moved to the pool")
|
||||
stats["LazyStatic: Removed"] = n_removed
|
||||
stats["LazyStatic: Kept"] = n_kept
|
||||
stats["LazyStatic: Restored"] = 0
|
||||
self.n_restored = 0
|
||||
self.n_iterations = 0
|
||||
@@ -160,25 +165,34 @@ class StaticLazyConstraintsComponent(Component):
|
||||
|
||||
def _check_and_add(self, solver: "LearningSolver") -> bool:
|
||||
assert solver.internal_solver is not None
|
||||
assert self.pool.names is not None
|
||||
if len(self.pool.names) == 0:
|
||||
logger.info("Lazy constraint pool is empty. Skipping violation check.")
|
||||
return False
|
||||
self.n_iterations += 1
|
||||
logger.info("Finding violated lazy constraints...")
|
||||
enforced: Dict[str, Constraint] = {}
|
||||
for (cid, c) in self.pool.items():
|
||||
if not solver.internal_solver.is_constraint_satisfied_old(
|
||||
c,
|
||||
tol=self.violation_tolerance,
|
||||
):
|
||||
enforced[cid] = c
|
||||
logger.info(f"{len(enforced)} violations found")
|
||||
for (cid, c) in enforced.items():
|
||||
del self.pool[cid]
|
||||
solver.internal_solver.add_constraint(c, name=cid)
|
||||
self.enforced_cids.add(cid)
|
||||
self.n_restored += 1
|
||||
logger.info(
|
||||
f"{len(enforced)} constraints restored; {len(self.pool)} in the pool"
|
||||
is_satisfied = solver.internal_solver.are_constraints_satisfied(
|
||||
self.pool,
|
||||
tol=self.violation_tolerance,
|
||||
)
|
||||
if len(enforced) > 0:
|
||||
self.n_iterations += 1
|
||||
is_violated = tuple(not i for i in is_satisfied)
|
||||
violated_constraints = self.pool[is_violated]
|
||||
satisfied_constraints = self.pool[is_satisfied]
|
||||
self.pool = satisfied_constraints
|
||||
assert violated_constraints.names is not None
|
||||
assert satisfied_constraints.names is not None
|
||||
n_violated = len(violated_constraints.names)
|
||||
n_satisfied = len(satisfied_constraints.names)
|
||||
logger.info(f"Found {n_violated} violated lazy constraints found")
|
||||
if n_violated > 0:
|
||||
logger.info(
|
||||
"Enforcing {n_violated} lazy constraints; "
|
||||
f"{n_satisfied} left in the pool..."
|
||||
)
|
||||
solver.internal_solver.add_constraints(violated_constraints)
|
||||
for (i, name) in enumerate(violated_constraints.names):
|
||||
self.enforced_cids.add(name)
|
||||
self.n_restored += 1
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@@ -194,12 +208,16 @@ class StaticLazyConstraintsComponent(Component):
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
cids: Dict[Hashable, List[str]] = {}
|
||||
assert sample.after_load is not None
|
||||
assert sample.after_load.constraints_old is not None
|
||||
for (cid, constr) in sample.after_load.constraints_old.items():
|
||||
constraints = sample.after_load.constraints
|
||||
assert constraints is not None
|
||||
assert constraints.names is not None
|
||||
assert constraints.lazy is not None
|
||||
assert constraints.categories is not None
|
||||
for (cidx, cname) in enumerate(constraints.names):
|
||||
# Initialize categories
|
||||
if not constr.lazy:
|
||||
if not constraints.lazy[cidx]:
|
||||
continue
|
||||
category = constr.category
|
||||
category = constraints.categories[cidx]
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
@@ -212,12 +230,11 @@ class StaticLazyConstraintsComponent(Component):
|
||||
if sample.after_lp is not None:
|
||||
sf = sample.after_lp
|
||||
assert sf.instance is not None
|
||||
assert sf.constraints is not None
|
||||
features = list(sf.instance.to_list())
|
||||
assert sf.constraints_old is not None
|
||||
assert sf.constraints_old[cid] is not None
|
||||
features.extend(sf.constraints_old[cid].to_list())
|
||||
features.extend(sf.constraints.to_list(cidx))
|
||||
x[category].append(features)
|
||||
cids[category].append(cid)
|
||||
cids[category].append(cname)
|
||||
|
||||
# Labels
|
||||
if (
|
||||
@@ -225,7 +242,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
and (sample.after_mip.extra is not None)
|
||||
and ("lazy_enforced" in sample.after_mip.extra)
|
||||
):
|
||||
if cid in sample.after_mip.extra["lazy_enforced"]:
|
||||
if cname in sample.after_mip.extra["lazy_enforced"]:
|
||||
y[category] += [[False, True]]
|
||||
else:
|
||||
y[category] += [[True, False]]
|
||||
|
||||
Reference in New Issue
Block a user