Allow user to attach arbitrary data to violations

This commit is contained in:
2022-01-25 11:39:03 -06:00
parent ba8f5bb2f4
commit 2a76dd42ec
12 changed files with 168 additions and 127 deletions

View File

@@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import json
import logging
from typing import Dict, List, Tuple, Optional, Any, Set
@@ -36,7 +36,7 @@ class DynamicConstraintsComponent(Component):
self.classifier_prototype: Classifier = classifier
self.classifiers: Dict[ConstraintCategory, Classifier] = {}
self.thresholds: Dict[ConstraintCategory, Threshold] = {}
self.known_cids: List[ConstraintName] = []
self.known_violations: Dict[ConstraintName, Any] = {}
self.attr = attr
def sample_xy_with_cids(
@@ -48,18 +48,19 @@ class DynamicConstraintsComponent(Component):
Dict[ConstraintCategory, List[List[bool]]],
Dict[ConstraintCategory, List[ConstraintName]],
]:
if len(self.known_cids) == 0:
if len(self.known_violations) == 0:
return {}, {}, {}
assert instance is not None
x: Dict[ConstraintCategory, List[List[float]]] = {}
y: Dict[ConstraintCategory, List[List[bool]]] = {}
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
known_cids = np.array(self.known_cids, dtype="S")
known_cids = np.array(sorted(list(self.known_violations.keys())), dtype="S")
enforced_cids = None
enforced_cids_np = sample.get_array(self.attr)
if enforced_cids_np is not None:
enforced_cids = list(enforced_cids_np)
enforced_encoded = sample.get_scalar(self.attr)
if enforced_encoded is not None:
enforced = self.decode(enforced_encoded)
enforced_cids = list(enforced.keys())
# Get user-provided constraint features
(
@@ -100,11 +101,10 @@ class DynamicConstraintsComponent(Component):
@overrides
def pre_fit(self, pre: List[Any]) -> None:
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= set(list(cids))
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
self.known_violations.clear()
for violations in pre:
for (vname, vdata) in violations.items():
self.known_violations[vname] = vdata
def sample_predict(
self,
@@ -112,7 +112,7 @@ class DynamicConstraintsComponent(Component):
sample: Sample,
) -> List[ConstraintName]:
pred: List[ConstraintName] = []
if len(self.known_cids) == 0:
if len(self.known_violations) == 0:
logger.info("Classifiers not fitted. Skipping.")
return pred
x, _, cids = self.sample_xy_with_cids(instance, sample)
@@ -131,7 +131,9 @@ class DynamicConstraintsComponent(Component):
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return sample.get_array(self.attr)
attr_encoded = sample.get_scalar(self.attr)
assert attr_encoded is not None
return self.decode(attr_encoded)
@overrides
def fit_xy(
@@ -153,11 +155,13 @@ class DynamicConstraintsComponent(Component):
instance: Instance,
sample: Sample,
) -> Dict[str, float]:
actual = sample.get_array(self.attr)
assert actual is not None
attr_encoded = sample.get_scalar(self.attr)
assert attr_encoded is not None
actual_violations = DynamicConstraintsComponent.decode(attr_encoded)
actual = set(actual_violations.keys())
pred = set(self.sample_predict(instance, sample))
tp, tn, fp, fn = 0, 0, 0, 0
for cid in self.known_cids:
for cid in self.known_violations.keys():
if cid in pred:
if cid in actual:
tp += 1
@@ -169,3 +173,12 @@ class DynamicConstraintsComponent(Component):
else:
tn += 1
return classifier_evaluation_dict(tp=tp, tn=tn, fp=fp, fn=fn)
@staticmethod
def encode(violations: Dict[ConstraintName, Any]) -> str:
return json.dumps({k.decode(): v for (k, v) in violations.items()})
@staticmethod
def decode(violations_encoded: str) -> Dict[ConstraintName, Any]:
violations = json.loads(violations_encoded)
return {k.encode(): v for (k, v) in violations.items()}

View File

@@ -1,10 +1,8 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
import pdb
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional
import numpy as np
from overrides import overrides
@@ -37,23 +35,23 @@ class DynamicLazyConstraintsComponent(Component):
self.dynamic: DynamicConstraintsComponent = DynamicConstraintsComponent(
classifier=classifier,
threshold=threshold,
attr="mip_constr_lazy_enforced",
attr="mip_constr_lazy",
)
self.classifiers = self.dynamic.classifiers
self.thresholds = self.dynamic.thresholds
self.known_cids = self.dynamic.known_cids
self.lazy_enforced: Set[ConstraintName] = set()
self.known_violations = self.dynamic.known_violations
self.lazy_enforced: Dict[ConstraintName, Any] = {}
@staticmethod
def enforce(
cids: List[ConstraintName],
violations: Dict[ConstraintName, Any],
instance: Instance,
model: Any,
solver: "LearningSolver",
) -> None:
assert solver.internal_solver is not None
for cid in cids:
instance.enforce_lazy_constraint(solver.internal_solver, model, cid)
for (vname, vdata) in violations.items():
instance.enforce_lazy_constraint(solver.internal_solver, model, vdata)
@overrides
def before_solve_mip(
@@ -66,9 +64,10 @@ class DynamicLazyConstraintsComponent(Component):
) -> None:
self.lazy_enforced.clear()
logger.info("Predicting violated (dynamic) lazy constraints...")
cids = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d lazy constraints..." % len(cids))
self.enforce(cids, instance, model, solver)
vnames = self.dynamic.sample_predict(instance, sample)
violations = {c: self.dynamic.known_violations[c] for c in vnames}
logger.info("Enforcing %d lazy constraints..." % len(vnames))
self.enforce(violations, instance, model, solver)
@overrides
def after_solve_mip(
@@ -79,10 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
sample.put_array(
"mip_constr_lazy_enforced",
np.array(list(self.lazy_enforced), dtype="S"),
)
sample.put_scalar("mip_constr_lazy", self.dynamic.encode(self.lazy_enforced))
@overrides
def iteration_cb(
@@ -93,14 +89,17 @@ class DynamicLazyConstraintsComponent(Component):
) -> bool:
assert solver.internal_solver is not None
logger.debug("Finding violated lazy constraints...")
cids = instance.find_violated_lazy_constraints(solver.internal_solver, model)
if len(cids) == 0:
violations = instance.find_violated_lazy_constraints(
solver.internal_solver, model
)
if len(violations) == 0:
logger.debug("No violations found")
return False
else:
self.lazy_enforced |= set(cids)
logger.debug(" %d violations found" % len(cids))
self.enforce(cids, instance, model, solver)
for v in violations:
self.lazy_enforced[v] = violations[v]
logger.debug(" %d violations found" % len(violations))
self.enforce(violations, instance, model, solver)
return True
# Delegate ML methods to self.dynamic

View File

@@ -1,9 +1,8 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any, TYPE_CHECKING, Set, Tuple, Dict, List, Optional
from typing import Any, TYPE_CHECKING, Tuple, Dict, List
import numpy as np
from overrides import overrides
@@ -32,9 +31,9 @@ class UserCutsComponent(Component):
self.dynamic = DynamicConstraintsComponent(
classifier=classifier,
threshold=threshold,
attr="mip_user_cuts_enforced",
attr="mip_user_cuts",
)
self.enforced: Set[ConstraintName] = set()
self.enforced: Dict[ConstraintName, Any] = {}
self.n_added_in_callback = 0
@overrides
@@ -50,11 +49,12 @@ class UserCutsComponent(Component):
self.enforced.clear()
self.n_added_in_callback = 0
logger.info("Predicting violated user cuts...")
cids = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
for cid in cids:
instance.enforce_user_cut(solver.internal_solver, model, cid)
stats["UserCuts: Added ahead-of-time"] = len(cids)
vnames = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d user cuts ahead-of-time..." % len(vnames))
for vname in vnames:
vdata = self.dynamic.known_violations[vname]
instance.enforce_user_cut(solver.internal_solver, model, vdata)
stats["UserCuts: Added ahead-of-time"] = len(vnames)
@overrides
def user_cut_cb(
@@ -65,18 +65,17 @@ class UserCutsComponent(Component):
) -> None:
assert solver.internal_solver is not None
logger.debug("Finding violated user cuts...")
cids = instance.find_violated_user_cuts(model)
logger.debug(f"Found {len(cids)} violated user cuts")
violations = instance.find_violated_user_cuts(model)
logger.debug(f"Found {len(violations)} violated user cuts")
logger.debug("Building violated user cuts...")
for cid in cids:
if cid in self.enforced:
for (vname, vdata) in violations.items():
if vname in self.enforced:
continue
assert isinstance(cid, ConstraintName)
instance.enforce_user_cut(solver.internal_solver, model, cid)
self.enforced.add(cid)
instance.enforce_user_cut(solver.internal_solver, model, vdata)
self.enforced[vname] = vdata
self.n_added_in_callback += 1
if len(cids) > 0:
logger.debug(f"Added {len(cids)} violated user cuts")
if len(violations) > 0:
logger.debug(f"Added {len(violations)} violated user cuts")
@overrides
def after_solve_mip(
@@ -87,10 +86,7 @@ class UserCutsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
sample.put_array(
"mip_user_cuts_enforced",
np.array(list(self.enforced), dtype="S"),
)
sample.put_scalar("mip_user_cuts", self.dynamic.encode(self.enforced))
stats["UserCuts: Added in callback"] = self.n_added_in_callback
if self.n_added_in_callback > 0:
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
@@ -133,5 +129,5 @@ class UserCutsComponent(Component):
self,
instance: "Instance",
sample: Sample,
) -> Dict[ConstraintCategory, Dict[str, float]]:
) -> Dict[ConstraintCategory, Dict[ConstraintName, float]]:
return self.dynamic.sample_evaluate(instance, sample)