mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Allow user to attach arbitrary data to violations
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Optional, Any, Set
|
||||
|
||||
@@ -36,7 +36,7 @@ class DynamicConstraintsComponent(Component):
|
||||
self.classifier_prototype: Classifier = classifier
|
||||
self.classifiers: Dict[ConstraintCategory, Classifier] = {}
|
||||
self.thresholds: Dict[ConstraintCategory, Threshold] = {}
|
||||
self.known_cids: List[ConstraintName] = []
|
||||
self.known_violations: Dict[ConstraintName, Any] = {}
|
||||
self.attr = attr
|
||||
|
||||
def sample_xy_with_cids(
|
||||
@@ -48,18 +48,19 @@ class DynamicConstraintsComponent(Component):
|
||||
Dict[ConstraintCategory, List[List[bool]]],
|
||||
Dict[ConstraintCategory, List[ConstraintName]],
|
||||
]:
|
||||
if len(self.known_cids) == 0:
|
||||
if len(self.known_violations) == 0:
|
||||
return {}, {}, {}
|
||||
assert instance is not None
|
||||
x: Dict[ConstraintCategory, List[List[float]]] = {}
|
||||
y: Dict[ConstraintCategory, List[List[bool]]] = {}
|
||||
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
|
||||
known_cids = np.array(self.known_cids, dtype="S")
|
||||
known_cids = np.array(sorted(list(self.known_violations.keys())), dtype="S")
|
||||
|
||||
enforced_cids = None
|
||||
enforced_cids_np = sample.get_array(self.attr)
|
||||
if enforced_cids_np is not None:
|
||||
enforced_cids = list(enforced_cids_np)
|
||||
enforced_encoded = sample.get_scalar(self.attr)
|
||||
if enforced_encoded is not None:
|
||||
enforced = self.decode(enforced_encoded)
|
||||
enforced_cids = list(enforced.keys())
|
||||
|
||||
# Get user-provided constraint features
|
||||
(
|
||||
@@ -100,11 +101,10 @@ class DynamicConstraintsComponent(Component):
|
||||
@overrides
|
||||
def pre_fit(self, pre: List[Any]) -> None:
|
||||
assert pre is not None
|
||||
known_cids: Set = set()
|
||||
for cids in pre:
|
||||
known_cids |= set(list(cids))
|
||||
self.known_cids.clear()
|
||||
self.known_cids.extend(sorted(known_cids))
|
||||
self.known_violations.clear()
|
||||
for violations in pre:
|
||||
for (vname, vdata) in violations.items():
|
||||
self.known_violations[vname] = vdata
|
||||
|
||||
def sample_predict(
|
||||
self,
|
||||
@@ -112,7 +112,7 @@ class DynamicConstraintsComponent(Component):
|
||||
sample: Sample,
|
||||
) -> List[ConstraintName]:
|
||||
pred: List[ConstraintName] = []
|
||||
if len(self.known_cids) == 0:
|
||||
if len(self.known_violations) == 0:
|
||||
logger.info("Classifiers not fitted. Skipping.")
|
||||
return pred
|
||||
x, _, cids = self.sample_xy_with_cids(instance, sample)
|
||||
@@ -131,7 +131,9 @@ class DynamicConstraintsComponent(Component):
|
||||
|
||||
@overrides
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||
return sample.get_array(self.attr)
|
||||
attr_encoded = sample.get_scalar(self.attr)
|
||||
assert attr_encoded is not None
|
||||
return self.decode(attr_encoded)
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
@@ -153,11 +155,13 @@ class DynamicConstraintsComponent(Component):
|
||||
instance: Instance,
|
||||
sample: Sample,
|
||||
) -> Dict[str, float]:
|
||||
actual = sample.get_array(self.attr)
|
||||
assert actual is not None
|
||||
attr_encoded = sample.get_scalar(self.attr)
|
||||
assert attr_encoded is not None
|
||||
actual_violations = DynamicConstraintsComponent.decode(attr_encoded)
|
||||
actual = set(actual_violations.keys())
|
||||
pred = set(self.sample_predict(instance, sample))
|
||||
tp, tn, fp, fn = 0, 0, 0, 0
|
||||
for cid in self.known_cids:
|
||||
for cid in self.known_violations.keys():
|
||||
if cid in pred:
|
||||
if cid in actual:
|
||||
tp += 1
|
||||
@@ -169,3 +173,12 @@ class DynamicConstraintsComponent(Component):
|
||||
else:
|
||||
tn += 1
|
||||
return classifier_evaluation_dict(tp=tp, tn=tn, fp=fp, fn=fn)
|
||||
|
||||
@staticmethod
|
||||
def encode(violations: Dict[ConstraintName, Any]) -> str:
|
||||
return json.dumps({k.decode(): v for (k, v) in violations.items()})
|
||||
|
||||
@staticmethod
|
||||
def decode(violations_encoded: str) -> Dict[ConstraintName, Any]:
|
||||
violations = json.loads(violations_encoded)
|
||||
return {k.encode(): v for (k, v) in violations.items()}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import pdb
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -37,23 +35,23 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
self.dynamic: DynamicConstraintsComponent = DynamicConstraintsComponent(
|
||||
classifier=classifier,
|
||||
threshold=threshold,
|
||||
attr="mip_constr_lazy_enforced",
|
||||
attr="mip_constr_lazy",
|
||||
)
|
||||
self.classifiers = self.dynamic.classifiers
|
||||
self.thresholds = self.dynamic.thresholds
|
||||
self.known_cids = self.dynamic.known_cids
|
||||
self.lazy_enforced: Set[ConstraintName] = set()
|
||||
self.known_violations = self.dynamic.known_violations
|
||||
self.lazy_enforced: Dict[ConstraintName, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def enforce(
|
||||
cids: List[ConstraintName],
|
||||
violations: Dict[ConstraintName, Any],
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
solver: "LearningSolver",
|
||||
) -> None:
|
||||
assert solver.internal_solver is not None
|
||||
for cid in cids:
|
||||
instance.enforce_lazy_constraint(solver.internal_solver, model, cid)
|
||||
for (vname, vdata) in violations.items():
|
||||
instance.enforce_lazy_constraint(solver.internal_solver, model, vdata)
|
||||
|
||||
@overrides
|
||||
def before_solve_mip(
|
||||
@@ -66,9 +64,10 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
) -> None:
|
||||
self.lazy_enforced.clear()
|
||||
logger.info("Predicting violated (dynamic) lazy constraints...")
|
||||
cids = self.dynamic.sample_predict(instance, sample)
|
||||
logger.info("Enforcing %d lazy constraints..." % len(cids))
|
||||
self.enforce(cids, instance, model, solver)
|
||||
vnames = self.dynamic.sample_predict(instance, sample)
|
||||
violations = {c: self.dynamic.known_violations[c] for c in vnames}
|
||||
logger.info("Enforcing %d lazy constraints..." % len(vnames))
|
||||
self.enforce(violations, instance, model, solver)
|
||||
|
||||
@overrides
|
||||
def after_solve_mip(
|
||||
@@ -79,10 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
stats: LearningSolveStats,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
sample.put_array(
|
||||
"mip_constr_lazy_enforced",
|
||||
np.array(list(self.lazy_enforced), dtype="S"),
|
||||
)
|
||||
sample.put_scalar("mip_constr_lazy", self.dynamic.encode(self.lazy_enforced))
|
||||
|
||||
@overrides
|
||||
def iteration_cb(
|
||||
@@ -93,14 +89,17 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
) -> bool:
|
||||
assert solver.internal_solver is not None
|
||||
logger.debug("Finding violated lazy constraints...")
|
||||
cids = instance.find_violated_lazy_constraints(solver.internal_solver, model)
|
||||
if len(cids) == 0:
|
||||
violations = instance.find_violated_lazy_constraints(
|
||||
solver.internal_solver, model
|
||||
)
|
||||
if len(violations) == 0:
|
||||
logger.debug("No violations found")
|
||||
return False
|
||||
else:
|
||||
self.lazy_enforced |= set(cids)
|
||||
logger.debug(" %d violations found" % len(cids))
|
||||
self.enforce(cids, instance, model, solver)
|
||||
for v in violations:
|
||||
self.lazy_enforced[v] = violations[v]
|
||||
logger.debug(" %d violations found" % len(violations))
|
||||
self.enforce(violations, instance, model, solver)
|
||||
return True
|
||||
|
||||
# Delegate ML methods to self.dynamic
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Any, TYPE_CHECKING, Set, Tuple, Dict, List, Optional
|
||||
from typing import Any, TYPE_CHECKING, Tuple, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -32,9 +31,9 @@ class UserCutsComponent(Component):
|
||||
self.dynamic = DynamicConstraintsComponent(
|
||||
classifier=classifier,
|
||||
threshold=threshold,
|
||||
attr="mip_user_cuts_enforced",
|
||||
attr="mip_user_cuts",
|
||||
)
|
||||
self.enforced: Set[ConstraintName] = set()
|
||||
self.enforced: Dict[ConstraintName, Any] = {}
|
||||
self.n_added_in_callback = 0
|
||||
|
||||
@overrides
|
||||
@@ -50,11 +49,12 @@ class UserCutsComponent(Component):
|
||||
self.enforced.clear()
|
||||
self.n_added_in_callback = 0
|
||||
logger.info("Predicting violated user cuts...")
|
||||
cids = self.dynamic.sample_predict(instance, sample)
|
||||
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
|
||||
for cid in cids:
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
stats["UserCuts: Added ahead-of-time"] = len(cids)
|
||||
vnames = self.dynamic.sample_predict(instance, sample)
|
||||
logger.info("Enforcing %d user cuts ahead-of-time..." % len(vnames))
|
||||
for vname in vnames:
|
||||
vdata = self.dynamic.known_violations[vname]
|
||||
instance.enforce_user_cut(solver.internal_solver, model, vdata)
|
||||
stats["UserCuts: Added ahead-of-time"] = len(vnames)
|
||||
|
||||
@overrides
|
||||
def user_cut_cb(
|
||||
@@ -65,18 +65,17 @@ class UserCutsComponent(Component):
|
||||
) -> None:
|
||||
assert solver.internal_solver is not None
|
||||
logger.debug("Finding violated user cuts...")
|
||||
cids = instance.find_violated_user_cuts(model)
|
||||
logger.debug(f"Found {len(cids)} violated user cuts")
|
||||
violations = instance.find_violated_user_cuts(model)
|
||||
logger.debug(f"Found {len(violations)} violated user cuts")
|
||||
logger.debug("Building violated user cuts...")
|
||||
for cid in cids:
|
||||
if cid in self.enforced:
|
||||
for (vname, vdata) in violations.items():
|
||||
if vname in self.enforced:
|
||||
continue
|
||||
assert isinstance(cid, ConstraintName)
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
self.enforced.add(cid)
|
||||
instance.enforce_user_cut(solver.internal_solver, model, vdata)
|
||||
self.enforced[vname] = vdata
|
||||
self.n_added_in_callback += 1
|
||||
if len(cids) > 0:
|
||||
logger.debug(f"Added {len(cids)} violated user cuts")
|
||||
if len(violations) > 0:
|
||||
logger.debug(f"Added {len(violations)} violated user cuts")
|
||||
|
||||
@overrides
|
||||
def after_solve_mip(
|
||||
@@ -87,10 +86,7 @@ class UserCutsComponent(Component):
|
||||
stats: LearningSolveStats,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
sample.put_array(
|
||||
"mip_user_cuts_enforced",
|
||||
np.array(list(self.enforced), dtype="S"),
|
||||
)
|
||||
sample.put_scalar("mip_user_cuts", self.dynamic.encode(self.enforced))
|
||||
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
||||
if self.n_added_in_callback > 0:
|
||||
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
||||
@@ -133,5 +129,5 @@ class UserCutsComponent(Component):
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: Sample,
|
||||
) -> Dict[ConstraintCategory, Dict[str, float]]:
|
||||
) -> Dict[ConstraintCategory, Dict[ConstraintName, float]]:
|
||||
return self.dynamic.sample_evaluate(instance, sample)
|
||||
|
||||
Reference in New Issue
Block a user