Allow user to attach arbitrary data to violations

This commit is contained in:
2022-01-25 11:39:03 -06:00
parent ba8f5bb2f4
commit 2a76dd42ec
12 changed files with 168 additions and 127 deletions

View File

@@ -9,7 +9,7 @@ from typing import Any, List, TYPE_CHECKING, Dict
import numpy as np
from miplearn.features.sample import Sample, MemorySample
from miplearn.types import ConstraintName, ConstraintCategory
from miplearn.types import ConstraintName
logger = logging.getLogger(__name__)
@@ -114,7 +114,7 @@ class Instance(ABC):
self,
solver: "InternalSolver",
model: Any,
) -> List[ConstraintName]:
) -> Dict[ConstraintName, Any]:
"""
Returns lazy constraint violations found for the current solution.
@@ -124,40 +124,46 @@ class Instance(ABC):
resolve the problem. The process repeats until no further lazy constraint
violations are found.
Each "violation" is simply a string which allows the instance to identify
unambiguously which lazy constraint should be generated. In the Traveling
Salesman Problem, for example, a subtour violation could be a string
containing the cities in the subtour.
Violations should be returned in a dictionary mapping the name of the violation
to some user-specified data that allows the instance to unambiguously generate
the lazy constraints at a later time. In the Traveling Salesman Problem, for
example, this function could return a dictionary identifying violated subtour
inequalities. More concretely, it could return:
{
"s1": [1, 2, 3],
"s2": [4, 5, 6, 7],
}
where "s1" and "s2" are the names of the subtours, and [1,2,3] and [4,5,6,7]
are the cities in each subtour. The names of the violations should be kept
stable across instances. In our example, "s1" should always correspond to
[1,2,3] across all instances. The user-provided data should be picklable.
The current solution can be queried with `solver.get_solution()`. If the solver
is configured to use lazy callbacks, this solution may be non-integer.
For a concrete example, see TravelingSalesmanInstance.
"""
return []
return {}
def enforce_lazy_constraint(
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
"""
Adds constraints to the model to ensure that the given violation is fixed.
This method is typically called immediately after
find_violated_lazy_constraints. The violation object provided to this method
is exactly the same object returned earlier by
find_violated_lazy_constraints. After some training, LearningSolver may
decide to proactively build some lazy constraints at the beginning of the
optimization process, before a solution is even available. In this case,
enforce_lazy_constraints will be called without a corresponding call to
find_violated_lazy_constraints.
`find_violated_lazy_constraints`. The argument `violation_data` is the
user-provided data, previously returned by `find_violated_lazy_constraints`.
In the Traveling Salesman Problem, for example, it could be a list of cities
in the subtour.
Note that this method can be called either before the optimization starts or
from within a callback. To ensure that constraints are added correctly in
either case, it is recommended to use `solver.add_constraint`, instead of
modifying the `model` object directly.
After some training, LearningSolver may decide to proactively build some lazy
constraints at the beginning of the optimization process, before a solution
is even available. In this case, `enforce_lazy_constraints` will be called
without a corresponding call to `find_violated_lazy_constraints`.
For a concrete example, see TravelingSalesmanInstance.
"""
@@ -166,14 +172,14 @@ class Instance(ABC):
def has_user_cuts(self) -> bool:
return False
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
return []
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
return {}
def enforce_user_cut(
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> Any:
return None

View File

@@ -3,15 +3,15 @@
# Released under the modified BSD license. See COPYING.md for more details.
import gc
import os
from typing import Any, Optional, List, Dict, TYPE_CHECKING
import pickle
from typing import Any, Optional, List, Dict, TYPE_CHECKING
import numpy as np
from overrides import overrides
from miplearn.features.sample import Hdf5Sample, Sample
from miplearn.instance.base import Instance
from miplearn.types import ConstraintName, ConstraintCategory
from miplearn.types import ConstraintName
if TYPE_CHECKING:
from miplearn.solvers.learning import InternalSolver
@@ -71,7 +71,7 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
) -> List[ConstraintName]:
) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_lazy_constraints(solver, model)
@@ -80,13 +80,13 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_lazy_constraint(solver, model, violation)
self.instance.enforce_lazy_constraint(solver, model, violation_data)
@overrides
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_user_cuts(model)
@@ -95,10 +95,10 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_user_cut(solver, model, violation)
self.instance.enforce_user_cut(solver, model, violation_data)
# Input & Output
# -------------------------------------------------------------------------

View File

@@ -13,7 +13,7 @@ from overrides import overrides
from miplearn.features.sample import Sample
from miplearn.instance.base import Instance
from miplearn.types import ConstraintName, ConstraintCategory
from miplearn.types import ConstraintName
if TYPE_CHECKING:
from miplearn.solvers.learning import InternalSolver
@@ -83,7 +83,7 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
) -> List[ConstraintName]:
) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_lazy_constraints(solver, model)
@@ -92,13 +92,13 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_data: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_lazy_constraint(solver, model, violation)
self.instance.enforce_lazy_constraint(solver, model, violation_data)
@overrides
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
assert self.instance is not None
return self.instance.find_violated_user_cuts(model)
@@ -107,10 +107,10 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: ConstraintName,
violation_name: Any,
) -> None:
assert self.instance is not None
self.instance.enforce_user_cut(solver, model, violation)
self.instance.enforce_user_cut(solver, model, violation_name)
@overrides
def load(self) -> None: