Allow user to attach arbitrary data to violations

This commit is contained in:
2022-01-25 11:39:03 -06:00
parent ba8f5bb2f4
commit 2a76dd42ec
12 changed files with 168 additions and 127 deletions

View File

@@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import json
import logging
from typing import Dict, List, Tuple, Optional, Any, Set
@@ -36,7 +36,7 @@ class DynamicConstraintsComponent(Component):
self.classifier_prototype: Classifier = classifier
self.classifiers: Dict[ConstraintCategory, Classifier] = {}
self.thresholds: Dict[ConstraintCategory, Threshold] = {}
self.known_cids: List[ConstraintName] = []
self.known_violations: Dict[ConstraintName, Any] = {}
self.attr = attr
def sample_xy_with_cids(
@@ -48,18 +48,19 @@ class DynamicConstraintsComponent(Component):
Dict[ConstraintCategory, List[List[bool]]],
Dict[ConstraintCategory, List[ConstraintName]],
]:
if len(self.known_cids) == 0:
if len(self.known_violations) == 0:
return {}, {}, {}
assert instance is not None
x: Dict[ConstraintCategory, List[List[float]]] = {}
y: Dict[ConstraintCategory, List[List[bool]]] = {}
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
known_cids = np.array(self.known_cids, dtype="S")
known_cids = np.array(sorted(list(self.known_violations.keys())), dtype="S")
enforced_cids = None
enforced_cids_np = sample.get_array(self.attr)
if enforced_cids_np is not None:
enforced_cids = list(enforced_cids_np)
enforced_encoded = sample.get_scalar(self.attr)
if enforced_encoded is not None:
enforced = self.decode(enforced_encoded)
enforced_cids = list(enforced.keys())
# Get user-provided constraint features
(
@@ -100,11 +101,10 @@ class DynamicConstraintsComponent(Component):
@overrides
def pre_fit(self, pre: List[Any]) -> None:
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= set(list(cids))
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
self.known_violations.clear()
for violations in pre:
for (vname, vdata) in violations.items():
self.known_violations[vname] = vdata
def sample_predict(
self,
@@ -112,7 +112,7 @@ class DynamicConstraintsComponent(Component):
sample: Sample,
) -> List[ConstraintName]:
pred: List[ConstraintName] = []
if len(self.known_cids) == 0:
if len(self.known_violations) == 0:
logger.info("Classifiers not fitted. Skipping.")
return pred
x, _, cids = self.sample_xy_with_cids(instance, sample)
@@ -131,7 +131,9 @@ class DynamicConstraintsComponent(Component):
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return sample.get_array(self.attr)
attr_encoded = sample.get_scalar(self.attr)
assert attr_encoded is not None
return self.decode(attr_encoded)
@overrides
def fit_xy(
@@ -153,11 +155,13 @@ class DynamicConstraintsComponent(Component):
instance: Instance,
sample: Sample,
) -> Dict[str, float]:
actual = sample.get_array(self.attr)
assert actual is not None
attr_encoded = sample.get_scalar(self.attr)
assert attr_encoded is not None
actual_violations = DynamicConstraintsComponent.decode(attr_encoded)
actual = set(actual_violations.keys())
pred = set(self.sample_predict(instance, sample))
tp, tn, fp, fn = 0, 0, 0, 0
for cid in self.known_cids:
for cid in self.known_violations.keys():
if cid in pred:
if cid in actual:
tp += 1
@@ -169,3 +173,12 @@ class DynamicConstraintsComponent(Component):
else:
tn += 1
return classifier_evaluation_dict(tp=tp, tn=tn, fp=fp, fn=fn)
@staticmethod
def encode(violations: Dict[ConstraintName, Any]) -> str:
return json.dumps({k.decode(): v for (k, v) in violations.items()})
@staticmethod
def decode(violations_encoded: str) -> Dict[ConstraintName, Any]:
violations = json.loads(violations_encoded)
return {k.encode(): v for (k, v) in violations.items()}

View File

@@ -1,10 +1,8 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
import pdb
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional
import numpy as np
from overrides import overrides
@@ -37,23 +35,23 @@ class DynamicLazyConstraintsComponent(Component):
self.dynamic: DynamicConstraintsComponent = DynamicConstraintsComponent(
classifier=classifier,
threshold=threshold,
attr="mip_constr_lazy_enforced",
attr="mip_constr_lazy",
)
self.classifiers = self.dynamic.classifiers
self.thresholds = self.dynamic.thresholds
self.known_cids = self.dynamic.known_cids
self.lazy_enforced: Set[ConstraintName] = set()
self.known_violations = self.dynamic.known_violations
self.lazy_enforced: Dict[ConstraintName, Any] = {}
@staticmethod
def enforce(
cids: List[ConstraintName],
violations: Dict[ConstraintName, Any],
instance: Instance,
model: Any,
solver: "LearningSolver",
) -> None:
assert solver.internal_solver is not None
for cid in cids:
instance.enforce_lazy_constraint(solver.internal_solver, model, cid)
for (vname, vdata) in violations.items():
instance.enforce_lazy_constraint(solver.internal_solver, model, vdata)
@overrides
def before_solve_mip(
@@ -66,9 +64,10 @@ class DynamicLazyConstraintsComponent(Component):
) -> None:
self.lazy_enforced.clear()
logger.info("Predicting violated (dynamic) lazy constraints...")
cids = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d lazy constraints..." % len(cids))
self.enforce(cids, instance, model, solver)
vnames = self.dynamic.sample_predict(instance, sample)
violations = {c: self.dynamic.known_violations[c] for c in vnames}
logger.info("Enforcing %d lazy constraints..." % len(vnames))
self.enforce(violations, instance, model, solver)
@overrides
def after_solve_mip(
@@ -79,10 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
sample.put_array(
"mip_constr_lazy_enforced",
np.array(list(self.lazy_enforced), dtype="S"),
)
sample.put_scalar("mip_constr_lazy", self.dynamic.encode(self.lazy_enforced))
@overrides
def iteration_cb(
@@ -93,14 +89,17 @@ class DynamicLazyConstraintsComponent(Component):
) -> bool:
assert solver.internal_solver is not None
logger.debug("Finding violated lazy constraints...")
cids = instance.find_violated_lazy_constraints(solver.internal_solver, model)
if len(cids) == 0:
violations = instance.find_violated_lazy_constraints(
solver.internal_solver, model
)
if len(violations) == 0:
logger.debug("No violations found")
return False
else:
self.lazy_enforced |= set(cids)
logger.debug(" %d violations found" % len(cids))
self.enforce(cids, instance, model, solver)
for v in violations:
self.lazy_enforced[v] = violations[v]
logger.debug(" %d violations found" % len(violations))
self.enforce(violations, instance, model, solver)
return True
# Delegate ML methods to self.dynamic

View File

@@ -1,9 +1,8 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any, TYPE_CHECKING, Set, Tuple, Dict, List, Optional
from typing import Any, TYPE_CHECKING, Tuple, Dict, List
import numpy as np
from overrides import overrides
@@ -32,9 +31,9 @@ class UserCutsComponent(Component):
self.dynamic = DynamicConstraintsComponent(
classifier=classifier,
threshold=threshold,
attr="mip_user_cuts_enforced",
attr="mip_user_cuts",
)
self.enforced: Set[ConstraintName] = set()
self.enforced: Dict[ConstraintName, Any] = {}
self.n_added_in_callback = 0
@overrides
@@ -50,11 +49,12 @@ class UserCutsComponent(Component):
self.enforced.clear()
self.n_added_in_callback = 0
logger.info("Predicting violated user cuts...")
cids = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
for cid in cids:
instance.enforce_user_cut(solver.internal_solver, model, cid)
stats["UserCuts: Added ahead-of-time"] = len(cids)
vnames = self.dynamic.sample_predict(instance, sample)
logger.info("Enforcing %d user cuts ahead-of-time..." % len(vnames))
for vname in vnames:
vdata = self.dynamic.known_violations[vname]
instance.enforce_user_cut(solver.internal_solver, model, vdata)
stats["UserCuts: Added ahead-of-time"] = len(vnames)
@overrides
def user_cut_cb(
@@ -65,18 +65,17 @@ class UserCutsComponent(Component):
) -> None:
assert solver.internal_solver is not None
logger.debug("Finding violated user cuts...")
cids = instance.find_violated_user_cuts(model)
logger.debug(f"Found {len(cids)} violated user cuts")
violations = instance.find_violated_user_cuts(model)
logger.debug(f"Found {len(violations)} violated user cuts")
logger.debug("Building violated user cuts...")
for cid in cids:
if cid in self.enforced:
for (vname, vdata) in violations.items():
if vname in self.enforced:
continue
assert isinstance(cid, ConstraintName)
instance.enforce_user_cut(solver.internal_solver, model, cid)
self.enforced.add(cid)
instance.enforce_user_cut(solver.internal_solver, model, vdata)
self.enforced[vname] = vdata
self.n_added_in_callback += 1
if len(cids) > 0:
logger.debug(f"Added {len(cids)} violated user cuts")
if len(violations) > 0:
logger.debug(f"Added {len(violations)} violated user cuts")
@overrides
def after_solve_mip(
@@ -87,10 +86,7 @@ class UserCutsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
sample.put_array(
"mip_user_cuts_enforced",
np.array(list(self.enforced), dtype="S"),
)
sample.put_scalar("mip_user_cuts", self.dynamic.encode(self.enforced))
stats["UserCuts: Added in callback"] = self.n_added_in_callback
if self.n_added_in_callback > 0:
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
@@ -133,5 +129,5 @@ class UserCutsComponent(Component):
self,
instance: "Instance",
sample: Sample,
) -> Dict[ConstraintCategory, Dict[str, float]]:
) -> Dict[ConstraintCategory, Dict[ConstraintName, float]]:
return self.dynamic.sample_evaluate(instance, sample)

View File

@@ -9,7 +9,7 @@ from typing import Any, List, TYPE_CHECKING, Dict
import numpy as np
from miplearn.features.sample import Sample, MemorySample
from miplearn.types import ConstraintName, ConstraintCategory
from miplearn.types import ConstraintName
logger = logging.getLogger(__name__)
@@ -114,7 +114,7 @@ class Instance(ABC):
self,
solver: "InternalSolver",
model: Any,
) -> List[ConstraintName]:
) -> Dict[ConstraintName, Any]:
"""
Returns lazy constraint violations found for the current solution.
@@ -124,40 +124,46 @@ class Instance(ABC):
resolve the problem. The process repeats until no further lazy constraint
violations are found.
Each "violation" is simply a string which allows the instance to identify
unambiguously which lazy constraint should be generated. In the Traveling
Salesman Problem, for example, a subtour violation could be a string
containing the cities in the subtour.
Violations should be returned in a dictionary mapping the name of the violation
to some user-specified data that allows the instance to unambiguously generate
the lazy constraints at a later time. In the Traveling Salesman Problem, for
example, this function could return a dictionary identifying violated subtour
inequalities. More concretely, it could return:
{
"s1": [1, 2, 3],
"s2": [4, 5, 6, 7],
}
where "s1" and "s2" are the names of the subtours, and [1,2,3] and [4,5,6,7]
are the cities in each subtour. The names of the violations should be kept
stable across instances. In our example, "s1" should always correspond to
[1,2,3] across all instances. The user-provided data should be picklable.
The current solution can be queried with `solver.get_solution()`. If the solver
is configured to use lazy callbacks, this solution may be non-integer.
For a concrete example, see TravelingSalesmanInstance.
"""
return []
return {}
def enforce_lazy_constraint(
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
"""
Adds constraints to the model to ensure that the given violation is fixed.
This method is typically called immediately after
find_violated_lazy_constraints. The violation object provided to this method
is exactly the same object returned earlier by
find_violated_lazy_constraints. After some training, LearningSolver may
decide to proactively build some lazy constraints at the beginning of the
optimization process, before a solution is even available. In this case,
enforce_lazy_constraints will be called without a corresponding call to
find_violated_lazy_constraints.
`find_violated_lazy_constraints`. The argument `violation_data` is the
user-provided data, previously returned by `find_violated_lazy_constraints`.
In the Traveling Salesman Problem, for example, it could be a list of cities
in the subtour.
Note that this method can be called either before the optimization starts or
from within a callback. To ensure that constraints are added correctly in
either case, it is recommended to use `solver.add_constraint`, instead of
modifying the `model` object directly.
After some training, LearningSolver may decide to proactively build some lazy
constraints at the beginning of the optimization process, before a solution
is even available. In this case, `enforce_lazy_constraints` will be called
without a corresponding call to `find_violated_lazy_constraints`.
For a concrete example, see TravelingSalesmanInstance.
"""
@@ -166,14 +172,14 @@ class Instance(ABC):
def has_user_cuts(self) -> bool:
return False
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
return []
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
return {}
def enforce_user_cut(
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> Any:
return None

View File

@@ -3,15 +3,15 @@
# Released under the modified BSD license. See COPYING.md for more details.
import gc
import os
from typing import Any, Optional, List, Dict, TYPE_CHECKING
import pickle
from typing import Any, Optional, List, Dict, TYPE_CHECKING
import numpy as np
from overrides import overrides
from miplearn.features.sample import Hdf5Sample, Sample
from miplearn.instance.base import Instance
from miplearn.types import ConstraintName, ConstraintCategory
from miplearn.types import ConstraintName
if TYPE_CHECKING:
from miplearn.solvers.learning import InternalSolver
@@ -71,7 +71,7 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
) -> List[ConstraintName]:
) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_lazy_constraints(solver, model)
@@ -80,13 +80,13 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_lazy_constraint(solver, model, violation)
self.instance.enforce_lazy_constraint(solver, model, violation_data)
@overrides
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_user_cuts(model)
@@ -95,10 +95,10 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_user_cut(solver, model, violation)
self.instance.enforce_user_cut(solver, model, violation_data)
# Input & Output
# -------------------------------------------------------------------------

View File

@@ -13,7 +13,7 @@ from overrides import overrides
from miplearn.features.sample import Sample
from miplearn.instance.base import Instance
from miplearn.types import ConstraintName, ConstraintCategory
from miplearn.types import ConstraintName
if TYPE_CHECKING:
from miplearn.solvers.learning import InternalSolver
@@ -83,7 +83,7 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
) -> List[ConstraintName]:
) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_lazy_constraints(solver, model)
@@ -92,13 +92,13 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_lazy_constraint(solver, model, violation)
self.instance.enforce_lazy_constraint(solver, model, violation_data)
@overrides
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_user_cuts(model)
@@ -107,10 +107,10 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_name: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_user_cut(solver, model, violation)
self.instance.enforce_user_cut(solver, model, violation_name)
@overrides
def load(self) -> None:

View File

@@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import List, Tuple, FrozenSet, Any, Optional, Dict
from typing import List, Tuple, Any, Optional, Dict
import networkx as nx
import numpy as np
@@ -86,14 +86,15 @@ class TravelingSalesmanInstance(Instance):
self,
solver: InternalSolver,
model: Any,
) -> List[ConstraintName]:
) -> Dict[ConstraintName, List]:
selected_edges = [e for e in self.edges if model.x[e].value > 0.5]
graph = nx.Graph()
graph.add_edges_from(selected_edges)
violations = []
violations = {}
for c in list(nx.connected_components(graph)):
if len(c) < self.n_cities:
violations.append(",".join(map(str, c)).encode())
cname = ("st[" + ",".join(map(str, c)) + "]").encode()
violations[cname] = list(c)
return violations
@overrides
@@ -101,10 +102,9 @@ class TravelingSalesmanInstance(Instance):
self,
solver: InternalSolver,
model: Any,
violation: ConstraintName,
component: List,
) -> None:
assert isinstance(solver, BasePyomoSolver)
component = [int(v) for v in violation.decode().split(",")]
cut_edges = [
e
for e in self.edges

View File

@@ -710,7 +710,7 @@ class GurobiTestInstanceKnapsack(PyomoTestInstanceKnapsack):
self,
solver: InternalSolver,
model: Any,
violation: str,
violation_data: Any,
) -> None:
x0 = model.getVarByName("x[0]")
model.cbLazy(x0 <= 0)

View File

@@ -247,7 +247,7 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None:
assert relsol is not None
assert relsol[b"x[0]"] is not None
if relsol[b"x[0]"] > 0:
instance.enforce_lazy_constraint(cb_solver, cb_model, b"cut")
instance.enforce_lazy_constraint(cb_solver, cb_model, None)
solver.set_instance(instance, model)
solver.solve(lazy_cb=lazy_cb)