mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Allow user to attach arbitrary data to violations
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Optional, Any, Set
|
||||
|
||||
@@ -36,7 +36,7 @@ class DynamicConstraintsComponent(Component):
|
||||
self.classifier_prototype: Classifier = classifier
|
||||
self.classifiers: Dict[ConstraintCategory, Classifier] = {}
|
||||
self.thresholds: Dict[ConstraintCategory, Threshold] = {}
|
||||
self.known_cids: List[ConstraintName] = []
|
||||
self.known_violations: Dict[ConstraintName, Any] = {}
|
||||
self.attr = attr
|
||||
|
||||
def sample_xy_with_cids(
|
||||
@@ -48,18 +48,19 @@ class DynamicConstraintsComponent(Component):
|
||||
Dict[ConstraintCategory, List[List[bool]]],
|
||||
Dict[ConstraintCategory, List[ConstraintName]],
|
||||
]:
|
||||
if len(self.known_cids) == 0:
|
||||
if len(self.known_violations) == 0:
|
||||
return {}, {}, {}
|
||||
assert instance is not None
|
||||
x: Dict[ConstraintCategory, List[List[float]]] = {}
|
||||
y: Dict[ConstraintCategory, List[List[bool]]] = {}
|
||||
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
|
||||
known_cids = np.array(self.known_cids, dtype="S")
|
||||
known_cids = np.array(sorted(list(self.known_violations.keys())), dtype="S")
|
||||
|
||||
enforced_cids = None
|
||||
enforced_cids_np = sample.get_array(self.attr)
|
||||
if enforced_cids_np is not None:
|
||||
enforced_cids = list(enforced_cids_np)
|
||||
enforced_encoded = sample.get_scalar(self.attr)
|
||||
if enforced_encoded is not None:
|
||||
enforced = self.decode(enforced_encoded)
|
||||
enforced_cids = list(enforced.keys())
|
||||
|
||||
# Get user-provided constraint features
|
||||
(
|
||||
@@ -100,11 +101,10 @@ class DynamicConstraintsComponent(Component):
|
||||
@overrides
|
||||
def pre_fit(self, pre: List[Any]) -> None:
|
||||
assert pre is not None
|
||||
known_cids: Set = set()
|
||||
for cids in pre:
|
||||
known_cids |= set(list(cids))
|
||||
self.known_cids.clear()
|
||||
self.known_cids.extend(sorted(known_cids))
|
||||
self.known_violations.clear()
|
||||
for violations in pre:
|
||||
for (vname, vdata) in violations.items():
|
||||
self.known_violations[vname] = vdata
|
||||
|
||||
def sample_predict(
|
||||
self,
|
||||
@@ -112,7 +112,7 @@ class DynamicConstraintsComponent(Component):
|
||||
sample: Sample,
|
||||
) -> List[ConstraintName]:
|
||||
pred: List[ConstraintName] = []
|
||||
if len(self.known_cids) == 0:
|
||||
if len(self.known_violations) == 0:
|
||||
logger.info("Classifiers not fitted. Skipping.")
|
||||
return pred
|
||||
x, _, cids = self.sample_xy_with_cids(instance, sample)
|
||||
@@ -131,7 +131,9 @@ class DynamicConstraintsComponent(Component):
|
||||
|
||||
@overrides
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||
return sample.get_array(self.attr)
|
||||
attr_encoded = sample.get_scalar(self.attr)
|
||||
assert attr_encoded is not None
|
||||
return self.decode(attr_encoded)
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
@@ -153,11 +155,13 @@ class DynamicConstraintsComponent(Component):
|
||||
instance: Instance,
|
||||
sample: Sample,
|
||||
) -> Dict[str, float]:
|
||||
actual = sample.get_array(self.attr)
|
||||
assert actual is not None
|
||||
attr_encoded = sample.get_scalar(self.attr)
|
||||
assert attr_encoded is not None
|
||||
actual_violations = DynamicConstraintsComponent.decode(attr_encoded)
|
||||
actual = set(actual_violations.keys())
|
||||
pred = set(self.sample_predict(instance, sample))
|
||||
tp, tn, fp, fn = 0, 0, 0, 0
|
||||
for cid in self.known_cids:
|
||||
for cid in self.known_violations.keys():
|
||||
if cid in pred:
|
||||
if cid in actual:
|
||||
tp += 1
|
||||
@@ -169,3 +173,12 @@ class DynamicConstraintsComponent(Component):
|
||||
else:
|
||||
tn += 1
|
||||
return classifier_evaluation_dict(tp=tp, tn=tn, fp=fp, fn=fn)
|
||||
|
||||
@staticmethod
|
||||
def encode(violations: Dict[ConstraintName, Any]) -> str:
|
||||
return json.dumps({k.decode(): v for (k, v) in violations.items()})
|
||||
|
||||
@staticmethod
|
||||
def decode(violations_encoded: str) -> Dict[ConstraintName, Any]:
|
||||
violations = json.loads(violations_encoded)
|
||||
return {k.encode(): v for (k, v) in violations.items()}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import pdb
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -37,23 +35,23 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
self.dynamic: DynamicConstraintsComponent = DynamicConstraintsComponent(
|
||||
classifier=classifier,
|
||||
threshold=threshold,
|
||||
attr="mip_constr_lazy_enforced",
|
||||
attr="mip_constr_lazy",
|
||||
)
|
||||
self.classifiers = self.dynamic.classifiers
|
||||
self.thresholds = self.dynamic.thresholds
|
||||
self.known_cids = self.dynamic.known_cids
|
||||
self.lazy_enforced: Set[ConstraintName] = set()
|
||||
self.known_violations = self.dynamic.known_violations
|
||||
self.lazy_enforced: Dict[ConstraintName, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def enforce(
|
||||
cids: List[ConstraintName],
|
||||
violations: Dict[ConstraintName, Any],
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
solver: "LearningSolver",
|
||||
) -> None:
|
||||
assert solver.internal_solver is not None
|
||||
for cid in cids:
|
||||
instance.enforce_lazy_constraint(solver.internal_solver, model, cid)
|
||||
for (vname, vdata) in violations.items():
|
||||
instance.enforce_lazy_constraint(solver.internal_solver, model, vdata)
|
||||
|
||||
@overrides
|
||||
def before_solve_mip(
|
||||
@@ -66,9 +64,10 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
) -> None:
|
||||
self.lazy_enforced.clear()
|
||||
logger.info("Predicting violated (dynamic) lazy constraints...")
|
||||
cids = self.dynamic.sample_predict(instance, sample)
|
||||
logger.info("Enforcing %d lazy constraints..." % len(cids))
|
||||
self.enforce(cids, instance, model, solver)
|
||||
vnames = self.dynamic.sample_predict(instance, sample)
|
||||
violations = {c: self.dynamic.known_violations[c] for c in vnames}
|
||||
logger.info("Enforcing %d lazy constraints..." % len(vnames))
|
||||
self.enforce(violations, instance, model, solver)
|
||||
|
||||
@overrides
|
||||
def after_solve_mip(
|
||||
@@ -79,10 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
stats: LearningSolveStats,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
sample.put_array(
|
||||
"mip_constr_lazy_enforced",
|
||||
np.array(list(self.lazy_enforced), dtype="S"),
|
||||
)
|
||||
sample.put_scalar("mip_constr_lazy", self.dynamic.encode(self.lazy_enforced))
|
||||
|
||||
@overrides
|
||||
def iteration_cb(
|
||||
@@ -93,14 +89,17 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
) -> bool:
|
||||
assert solver.internal_solver is not None
|
||||
logger.debug("Finding violated lazy constraints...")
|
||||
cids = instance.find_violated_lazy_constraints(solver.internal_solver, model)
|
||||
if len(cids) == 0:
|
||||
violations = instance.find_violated_lazy_constraints(
|
||||
solver.internal_solver, model
|
||||
)
|
||||
if len(violations) == 0:
|
||||
logger.debug("No violations found")
|
||||
return False
|
||||
else:
|
||||
self.lazy_enforced |= set(cids)
|
||||
logger.debug(" %d violations found" % len(cids))
|
||||
self.enforce(cids, instance, model, solver)
|
||||
for v in violations:
|
||||
self.lazy_enforced[v] = violations[v]
|
||||
logger.debug(" %d violations found" % len(violations))
|
||||
self.enforce(violations, instance, model, solver)
|
||||
return True
|
||||
|
||||
# Delegate ML methods to self.dynamic
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Any, TYPE_CHECKING, Set, Tuple, Dict, List, Optional
|
||||
from typing import Any, TYPE_CHECKING, Tuple, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -32,9 +31,9 @@ class UserCutsComponent(Component):
|
||||
self.dynamic = DynamicConstraintsComponent(
|
||||
classifier=classifier,
|
||||
threshold=threshold,
|
||||
attr="mip_user_cuts_enforced",
|
||||
attr="mip_user_cuts",
|
||||
)
|
||||
self.enforced: Set[ConstraintName] = set()
|
||||
self.enforced: Dict[ConstraintName, Any] = {}
|
||||
self.n_added_in_callback = 0
|
||||
|
||||
@overrides
|
||||
@@ -50,11 +49,12 @@ class UserCutsComponent(Component):
|
||||
self.enforced.clear()
|
||||
self.n_added_in_callback = 0
|
||||
logger.info("Predicting violated user cuts...")
|
||||
cids = self.dynamic.sample_predict(instance, sample)
|
||||
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
|
||||
for cid in cids:
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
stats["UserCuts: Added ahead-of-time"] = len(cids)
|
||||
vnames = self.dynamic.sample_predict(instance, sample)
|
||||
logger.info("Enforcing %d user cuts ahead-of-time..." % len(vnames))
|
||||
for vname in vnames:
|
||||
vdata = self.dynamic.known_violations[vname]
|
||||
instance.enforce_user_cut(solver.internal_solver, model, vdata)
|
||||
stats["UserCuts: Added ahead-of-time"] = len(vnames)
|
||||
|
||||
@overrides
|
||||
def user_cut_cb(
|
||||
@@ -65,18 +65,17 @@ class UserCutsComponent(Component):
|
||||
) -> None:
|
||||
assert solver.internal_solver is not None
|
||||
logger.debug("Finding violated user cuts...")
|
||||
cids = instance.find_violated_user_cuts(model)
|
||||
logger.debug(f"Found {len(cids)} violated user cuts")
|
||||
violations = instance.find_violated_user_cuts(model)
|
||||
logger.debug(f"Found {len(violations)} violated user cuts")
|
||||
logger.debug("Building violated user cuts...")
|
||||
for cid in cids:
|
||||
if cid in self.enforced:
|
||||
for (vname, vdata) in violations.items():
|
||||
if vname in self.enforced:
|
||||
continue
|
||||
assert isinstance(cid, ConstraintName)
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
self.enforced.add(cid)
|
||||
instance.enforce_user_cut(solver.internal_solver, model, vdata)
|
||||
self.enforced[vname] = vdata
|
||||
self.n_added_in_callback += 1
|
||||
if len(cids) > 0:
|
||||
logger.debug(f"Added {len(cids)} violated user cuts")
|
||||
if len(violations) > 0:
|
||||
logger.debug(f"Added {len(violations)} violated user cuts")
|
||||
|
||||
@overrides
|
||||
def after_solve_mip(
|
||||
@@ -87,10 +86,7 @@ class UserCutsComponent(Component):
|
||||
stats: LearningSolveStats,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
sample.put_array(
|
||||
"mip_user_cuts_enforced",
|
||||
np.array(list(self.enforced), dtype="S"),
|
||||
)
|
||||
sample.put_scalar("mip_user_cuts", self.dynamic.encode(self.enforced))
|
||||
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
||||
if self.n_added_in_callback > 0:
|
||||
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
||||
@@ -133,5 +129,5 @@ class UserCutsComponent(Component):
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: Sample,
|
||||
) -> Dict[ConstraintCategory, Dict[str, float]]:
|
||||
) -> Dict[ConstraintCategory, Dict[ConstraintName, float]]:
|
||||
return self.dynamic.sample_evaluate(instance, sample)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, List, TYPE_CHECKING, Dict
|
||||
import numpy as np
|
||||
|
||||
from miplearn.features.sample import Sample, MemorySample
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -114,7 +114,7 @@ class Instance(ABC):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[ConstraintName]:
|
||||
) -> Dict[ConstraintName, Any]:
|
||||
"""
|
||||
Returns lazy constraint violations found for the current solution.
|
||||
|
||||
@@ -124,40 +124,46 @@ class Instance(ABC):
|
||||
resolve the problem. The process repeats until no further lazy constraint
|
||||
violations are found.
|
||||
|
||||
Each "violation" is simply a string which allows the instance to identify
|
||||
unambiguously which lazy constraint should be generated. In the Traveling
|
||||
Salesman Problem, for example, a subtour violation could be a string
|
||||
containing the cities in the subtour.
|
||||
Violations should be returned in a dictionary mapping the name of the violation
|
||||
to some user-specified data that allows the instance to unambiguously generate
|
||||
the lazy constraints at a later time. In the Traveling Salesman Problem, for
|
||||
example, this function could return a dictionary identifying violated subtour
|
||||
inequalities. More concretely, it could return:
|
||||
{
|
||||
"s1": [1, 2, 3],
|
||||
"s2": [4, 5, 6, 7],
|
||||
}
|
||||
where "s1" and "s2" are the names of the subtours, and [1,2,3] and [4,5,6,7]
|
||||
are the cities in each subtour. The names of the violations should be kept
|
||||
stable across instances. In our example, "s1" should always correspond to
|
||||
[1,2,3] across all instances. The user-provided data should be picklable.
|
||||
|
||||
The current solution can be queried with `solver.get_solution()`. If the solver
|
||||
is configured to use lazy callbacks, this solution may be non-integer.
|
||||
|
||||
For a concrete example, see TravelingSalesmanInstance.
|
||||
"""
|
||||
return []
|
||||
return {}
|
||||
|
||||
def enforce_lazy_constraint(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Adds constraints to the model to ensure that the given violation is fixed.
|
||||
|
||||
This method is typically called immediately after
|
||||
find_violated_lazy_constraints. The violation object provided to this method
|
||||
is exactly the same object returned earlier by
|
||||
find_violated_lazy_constraints. After some training, LearningSolver may
|
||||
decide to proactively build some lazy constraints at the beginning of the
|
||||
optimization process, before a solution is even available. In this case,
|
||||
enforce_lazy_constraints will be called without a corresponding call to
|
||||
find_violated_lazy_constraints.
|
||||
`find_violated_lazy_constraints`. The argument `violation_data` is the
|
||||
user-provided data, previously returned by `find_violated_lazy_constraints`.
|
||||
In the Traveling Salesman Problem, for example, it could be a list of cities
|
||||
in the subtour.
|
||||
|
||||
Note that this method can be called either before the optimization starts or
|
||||
from within a callback. To ensure that constraints are added correctly in
|
||||
either case, it is recommended to use `solver.add_constraint`, instead of
|
||||
modifying the `model` object directly.
|
||||
After some training, LearningSolver may decide to proactively build some lazy
|
||||
constraints at the beginning of the optimization process, before a solution
|
||||
is even available. In this case, `enforce_lazy_constraints` will be called
|
||||
without a corresponding call to `find_violated_lazy_constraints`.
|
||||
|
||||
For a concrete example, see TravelingSalesmanInstance.
|
||||
"""
|
||||
@@ -166,14 +172,14 @@ class Instance(ABC):
|
||||
def has_user_cuts(self) -> bool:
|
||||
return False
|
||||
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
return []
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
return {}
|
||||
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
@@ -3,15 +3,15 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
import gc
|
||||
import os
|
||||
from typing import Any, Optional, List, Dict, TYPE_CHECKING
|
||||
import pickle
|
||||
from typing import Any, Optional, List, Dict, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Hdf5Sample, Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -71,7 +71,7 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[ConstraintName]:
|
||||
) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_lazy_constraints(solver, model)
|
||||
|
||||
@@ -80,13 +80,13 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation)
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation_data)
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@@ -95,10 +95,10 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
self.instance.enforce_user_cut(solver, model, violation_data)
|
||||
|
||||
# Input & Output
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@@ -13,7 +13,7 @@ from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -83,7 +83,7 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[ConstraintName]:
|
||||
) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_lazy_constraints(solver, model)
|
||||
|
||||
@@ -92,13 +92,13 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation)
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation_data)
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@@ -107,10 +107,10 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_name: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
self.instance.enforce_user_cut(solver, model, violation_name)
|
||||
|
||||
@overrides
|
||||
def load(self) -> None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from typing import List, Tuple, FrozenSet, Any, Optional, Dict
|
||||
from typing import List, Tuple, Any, Optional, Dict
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -86,14 +86,15 @@ class TravelingSalesmanInstance(Instance):
|
||||
self,
|
||||
solver: InternalSolver,
|
||||
model: Any,
|
||||
) -> List[ConstraintName]:
|
||||
) -> Dict[ConstraintName, List]:
|
||||
selected_edges = [e for e in self.edges if model.x[e].value > 0.5]
|
||||
graph = nx.Graph()
|
||||
graph.add_edges_from(selected_edges)
|
||||
violations = []
|
||||
violations = {}
|
||||
for c in list(nx.connected_components(graph)):
|
||||
if len(c) < self.n_cities:
|
||||
violations.append(",".join(map(str, c)).encode())
|
||||
cname = ("st[" + ",".join(map(str, c)) + "]").encode()
|
||||
violations[cname] = list(c)
|
||||
return violations
|
||||
|
||||
@overrides
|
||||
@@ -101,10 +102,9 @@ class TravelingSalesmanInstance(Instance):
|
||||
self,
|
||||
solver: InternalSolver,
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
component: List,
|
||||
) -> None:
|
||||
assert isinstance(solver, BasePyomoSolver)
|
||||
component = [int(v) for v in violation.decode().split(",")]
|
||||
cut_edges = [
|
||||
e
|
||||
for e in self.edges
|
||||
|
||||
@@ -710,7 +710,7 @@ class GurobiTestInstanceKnapsack(PyomoTestInstanceKnapsack):
|
||||
self,
|
||||
solver: InternalSolver,
|
||||
model: Any,
|
||||
violation: str,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
x0 = model.getVarByName("x[0]")
|
||||
model.cbLazy(x0 <= 0)
|
||||
|
||||
@@ -247,7 +247,7 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None:
|
||||
assert relsol is not None
|
||||
assert relsol[b"x[0]"] is not None
|
||||
if relsol[b"x[0]"] > 0:
|
||||
instance.enforce_lazy_constraint(cb_solver, cb_model, b"cut")
|
||||
instance.enforce_lazy_constraint(cb_solver, cb_model, None)
|
||||
|
||||
solver.set_instance(instance, model)
|
||||
solver.solve(lazy_cb=lazy_cb)
|
||||
|
||||
Reference in New Issue
Block a user