Small fixes to lazy constraints

This commit is contained in:
2020-09-24 14:30:29 -05:00
parent 0fe6aab98f
commit ba96338d2d
7 changed files with 58 additions and 39 deletions

View File

@@ -28,6 +28,7 @@ class StaticLazyConstraintsComponent(Component):
self.pool = []
def before_solve(self, solver, instance, model):
self.pool = []
instance.found_violated_lazy_constraints = []
if instance.has_static_lazy_constraints():
self._extract_and_predict_static(solver, instance)
@@ -36,21 +37,28 @@ class StaticLazyConstraintsComponent(Component):
pass
def after_iteration(self, solver, instance, model):
logger.debug("Finding violated (static) lazy constraints...")
n_added = 0
logger.info("Finding violated lazy constraints...")
constraints_to_add = []
for c in self.pool:
if not solver.internal_solver.is_constraint_satisfied(c.obj):
self.pool.remove(c)
solver.internal_solver.add_constraint(c.obj)
instance.found_violated_lazy_constraints += [c.cid]
n_added += 1
if n_added > 0:
logger.debug(" %d violations found" % n_added)
constraints_to_add.append(c)
for c in constraints_to_add:
self.pool.remove(c)
solver.internal_solver.add_constraint(c.obj)
instance.found_violated_lazy_constraints += [c.cid]
if len(constraints_to_add) > 0:
logger.info("Added %d lazy constraints back into the model" % len(constraints_to_add))
logger.info("Lazy constraint pool has %d constraints" % len(self.pool))
return True
else:
logger.info("Found no violated lazy constraints")
return False
def fit(self, training_instances):
training_instances = [t
for t in training_instances
if hasattr(t, "found_violated_lazy_constraints")]
logger.debug("Extracting x and y...")
x = self.x(training_instances)
y = self.y(training_instances)
@@ -72,11 +80,10 @@ class StaticLazyConstraintsComponent(Component):
def _extract_and_predict_static(self, solver, instance):
x = {}
constraints = {}
for cid in solver.internal_solver.get_constraint_names():
logger.info("Extracting lazy constraints...")
for cid in solver.internal_solver.get_constraint_ids():
if instance.is_constraint_lazy(cid):
category = instance.get_lazy_constraint_category(cid)
if category not in self.classifiers:
continue
if category not in x:
x[category] = []
constraints[category] = []
@@ -85,16 +92,24 @@ class StaticLazyConstraintsComponent(Component):
obj=solver.internal_solver.extract_constraint(cid))
constraints[category] += [c]
self.pool.append(c)
logger.info("Extracted %d lazy constraints" % len(self.pool))
logger.info("Predicting required lazy constraints...")
n_added = 0
for (category, x_values) in x.items():
if category not in self.classifiers:
continue
if isinstance(x_values[0], np.ndarray):
x[category] = np.array(x_values)
proba = self.classifiers[category].predict_proba(x[category])
for i in range(len(proba)):
if proba[i][1] > self.threshold:
n_added += 1
c = constraints[category][i]
self.pool.remove(c)
solver.internal_solver.add_constraint(c.obj)
instance.found_violated_lazy_constraints += [c.cid]
logger.info("Added %d lazy constraints back into the model" % n_added)
logger.info("Lazy constraint pool has %d constraints" % len(self.pool))
def _collect_constraints(self, train_instances):
constraints = {}

View File

@@ -14,7 +14,7 @@ from miplearn.classifiers import Classifier
def test_usage_with_solver():
solver = Mock(spec=LearningSolver)
internal = solver.internal_solver = Mock(spec=InternalSolver)
internal.get_constraint_names = Mock(return_value=["c1", "c2", "c3", "c4"])
internal.get_constraint_ids = Mock(return_value=["c1", "c2", "c3", "c4"])
internal.extract_constraint = Mock(side_effect=lambda cid: "<%s>" % cid)
internal.is_constraint_satisfied = Mock(return_value=False)
@@ -59,7 +59,7 @@ def test_usage_with_solver():
instance.has_static_lazy_constraints.assert_called_once()
# Should ask internal solver for a list of constraints in the model
internal.get_constraint_names.assert_called_once()
internal.get_constraint_ids.assert_called_once()
# Should ask if each constraint in the model is lazy
instance.is_constraint_lazy.assert_has_calls([