Refactor StaticLazy; remove old constraint methods

This commit is contained in:
2021-05-15 14:15:48 -05:00
parent 53d3e9d98a
commit 91c8db2225
10 changed files with 343 additions and 583 deletions

View File

@@ -12,7 +12,7 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
from miplearn.components.component import Component
from miplearn.features import Constraint, Sample
from miplearn.features import Constraint, Sample, ConstraintFeatures
from miplearn.instance.base import Instance
from miplearn.types import LearningSolveStats
@@ -45,7 +45,8 @@ class StaticLazyConstraintsComponent(Component):
self.threshold_prototype: Threshold = threshold
self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.pool: Dict[str, Constraint] = {}
self.pool_old: Dict[str, Constraint] = {}
self.pool: ConstraintFeatures = ConstraintFeatures()
self.violation_tolerance: float = violation_tolerance
self.enforced_cids: Set[Hashable] = set()
self.n_restored: int = 0
@@ -78,24 +79,28 @@ class StaticLazyConstraintsComponent(Component):
assert solver.internal_solver is not None
assert sample.after_load is not None
assert sample.after_load.instance is not None
assert sample.after_load.constraints_old is not None
logger.info("Predicting violated (static) lazy constraints...")
if sample.after_load.instance.lazy_constraint_count == 0:
logger.info("Instance does not have static lazy constraints. Skipping.")
self.enforced_cids = set(self.sample_predict(sample))
logger.info("Moving lazy constraints to the pool...")
self.pool = {}
for (cid, cdict) in sample.after_load.constraints_old.items():
if cdict.lazy and cid not in self.enforced_cids:
self.pool[cid] = cdict
solver.internal_solver.remove_constraint(cid)
logger.info(
f"{len(self.enforced_cids)} lazy constraints kept; "
f"{len(self.pool)} moved to the pool"
constraints = sample.after_load.constraints
assert constraints is not None
assert constraints.lazy is not None
assert constraints.names is not None
selected = tuple(
(constraints.lazy[i] and constraints.names[i] not in self.enforced_cids)
for i in range(len(constraints.lazy))
)
stats["LazyStatic: Removed"] = len(self.pool)
stats["LazyStatic: Kept"] = len(self.enforced_cids)
n_removed = sum(selected)
n_kept = sum(constraints.lazy) - n_removed
self.pool = constraints[selected]
assert self.pool.names is not None
solver.internal_solver.remove_constraints(self.pool.names)
logger.info(f"{n_kept} lazy constraints kept; {n_removed} moved to the pool")
stats["LazyStatic: Removed"] = n_removed
stats["LazyStatic: Kept"] = n_kept
stats["LazyStatic: Restored"] = 0
self.n_restored = 0
self.n_iterations = 0
@@ -160,25 +165,34 @@ class StaticLazyConstraintsComponent(Component):
def _check_and_add(self, solver: "LearningSolver") -> bool:
assert solver.internal_solver is not None
assert self.pool.names is not None
if len(self.pool.names) == 0:
logger.info("Lazy constraint pool is empty. Skipping violation check.")
return False
self.n_iterations += 1
logger.info("Finding violated lazy constraints...")
enforced: Dict[str, Constraint] = {}
for (cid, c) in self.pool.items():
if not solver.internal_solver.is_constraint_satisfied_old(
c,
tol=self.violation_tolerance,
):
enforced[cid] = c
logger.info(f"{len(enforced)} violations found")
for (cid, c) in enforced.items():
del self.pool[cid]
solver.internal_solver.add_constraint(c, name=cid)
self.enforced_cids.add(cid)
self.n_restored += 1
logger.info(
f"{len(enforced)} constraints restored; {len(self.pool)} in the pool"
is_satisfied = solver.internal_solver.are_constraints_satisfied(
self.pool,
tol=self.violation_tolerance,
)
if len(enforced) > 0:
self.n_iterations += 1
is_violated = tuple(not i for i in is_satisfied)
violated_constraints = self.pool[is_violated]
satisfied_constraints = self.pool[is_satisfied]
self.pool = satisfied_constraints
assert violated_constraints.names is not None
assert satisfied_constraints.names is not None
n_violated = len(violated_constraints.names)
n_satisfied = len(satisfied_constraints.names)
logger.info(f"Found {n_violated} violated lazy constraints found")
if n_violated > 0:
logger.info(
"Enforcing {n_violated} lazy constraints; "
f"{n_satisfied} left in the pool..."
)
solver.internal_solver.add_constraints(violated_constraints)
for (i, name) in enumerate(violated_constraints.names):
self.enforced_cids.add(name)
self.n_restored += 1
return True
else:
return False
@@ -194,12 +208,16 @@ class StaticLazyConstraintsComponent(Component):
y: Dict[Hashable, List[List[float]]] = {}
cids: Dict[Hashable, List[str]] = {}
assert sample.after_load is not None
assert sample.after_load.constraints_old is not None
for (cid, constr) in sample.after_load.constraints_old.items():
constraints = sample.after_load.constraints
assert constraints is not None
assert constraints.names is not None
assert constraints.lazy is not None
assert constraints.categories is not None
for (cidx, cname) in enumerate(constraints.names):
# Initialize categories
if not constr.lazy:
if not constraints.lazy[cidx]:
continue
category = constr.category
category = constraints.categories[cidx]
if category is None:
continue
if category not in x:
@@ -212,12 +230,11 @@ class StaticLazyConstraintsComponent(Component):
if sample.after_lp is not None:
sf = sample.after_lp
assert sf.instance is not None
assert sf.constraints is not None
features = list(sf.instance.to_list())
assert sf.constraints_old is not None
assert sf.constraints_old[cid] is not None
features.extend(sf.constraints_old[cid].to_list())
features.extend(sf.constraints.to_list(cidx))
x[category].append(features)
cids[category].append(cid)
cids[category].append(cname)
# Labels
if (
@@ -225,7 +242,7 @@ class StaticLazyConstraintsComponent(Component):
and (sample.after_mip.extra is not None)
and ("lazy_enforced" in sample.after_mip.extra)
):
if cid in sample.after_mip.extra["lazy_enforced"]:
if cname in sample.after_mip.extra["lazy_enforced"]:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]