mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 01:48:51 -06:00
Allow user to attach arbitrary data to violations
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Any, List, TYPE_CHECKING, Dict
|
||||
import numpy as np
|
||||
|
||||
from miplearn.features.sample import Sample, MemorySample
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -114,7 +114,7 @@ class Instance(ABC):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[ConstraintName]:
|
||||
) -> Dict[ConstraintName, Any]:
|
||||
"""
|
||||
Returns lazy constraint violations found for the current solution.
|
||||
|
||||
@@ -124,40 +124,46 @@ class Instance(ABC):
|
||||
resolve the problem. The process repeats until no further lazy constraint
|
||||
violations are found.
|
||||
|
||||
Each "violation" is simply a string which allows the instance to identify
|
||||
unambiguously which lazy constraint should be generated. In the Traveling
|
||||
Salesman Problem, for example, a subtour violation could be a string
|
||||
containing the cities in the subtour.
|
||||
Violations should be returned in a dictionary mapping the name of the violation
|
||||
to some user-specified data that allows the instance to unambiguously generate
|
||||
the lazy constraints at a later time. In the Traveling Salesman Problem, for
|
||||
example, this function could return a dictionary identifying violated subtour
|
||||
inequalities. More concretely, it could return:
|
||||
{
|
||||
"s1": [1, 2, 3],
|
||||
"s2": [4, 5, 6, 7],
|
||||
}
|
||||
where "s1" and "s2" are the names of the subtours, and [1,2,3] and [4,5,6,7]
|
||||
are the cities in each subtour. The names of the violations should be kept
|
||||
stable across instances. In our example, "s1" should always correspond to
|
||||
[1,2,3] across all instances. The user-provided data should be picklable.
|
||||
|
||||
The current solution can be queried with `solver.get_solution()`. If the solver
|
||||
is configured to use lazy callbacks, this solution may be non-integer.
|
||||
|
||||
For a concrete example, see TravelingSalesmanInstance.
|
||||
"""
|
||||
return []
|
||||
return {}
|
||||
|
||||
def enforce_lazy_constraint(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Adds constraints to the model to ensure that the given violation is fixed.
|
||||
|
||||
This method is typically called immediately after
|
||||
find_violated_lazy_constraints. The violation object provided to this method
|
||||
is exactly the same object returned earlier by
|
||||
find_violated_lazy_constraints. After some training, LearningSolver may
|
||||
decide to proactively build some lazy constraints at the beginning of the
|
||||
optimization process, before a solution is even available. In this case,
|
||||
enforce_lazy_constraints will be called without a corresponding call to
|
||||
find_violated_lazy_constraints.
|
||||
`find_violated_lazy_constraints`. The argument `violation_data` is the
|
||||
user-provided data, previously returned by `find_violated_lazy_constraints`.
|
||||
In the Traveling Salesman Problem, for example, it could be a list of cities
|
||||
in the subtour.
|
||||
|
||||
Note that this method can be called either before the optimization starts or
|
||||
from within a callback. To ensure that constraints are added correctly in
|
||||
either case, it is recommended to use `solver.add_constraint`, instead of
|
||||
modifying the `model` object directly.
|
||||
After some training, LearningSolver may decide to proactively build some lazy
|
||||
constraints at the beginning of the optimization process, before a solution
|
||||
is even available. In this case, `enforce_lazy_constraints` will be called
|
||||
without a corresponding call to `find_violated_lazy_constraints`.
|
||||
|
||||
For a concrete example, see TravelingSalesmanInstance.
|
||||
"""
|
||||
@@ -166,14 +172,14 @@ class Instance(ABC):
|
||||
def has_user_cuts(self) -> bool:
|
||||
return False
|
||||
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
return []
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
return {}
|
||||
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
@@ -3,15 +3,15 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
import gc
|
||||
import os
|
||||
from typing import Any, Optional, List, Dict, TYPE_CHECKING
|
||||
import pickle
|
||||
from typing import Any, Optional, List, Dict, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Hdf5Sample, Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -71,7 +71,7 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[ConstraintName]:
|
||||
) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_lazy_constraints(solver, model)
|
||||
|
||||
@@ -80,13 +80,13 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation)
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation_data)
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@@ -95,10 +95,10 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
self.instance.enforce_user_cut(solver, model, violation_data)
|
||||
|
||||
# Input & Output
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@@ -13,7 +13,7 @@ from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -83,7 +83,7 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[ConstraintName]:
|
||||
) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_lazy_constraints(solver, model)
|
||||
|
||||
@@ -92,13 +92,13 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_data: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation)
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation_data)
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
def find_violated_user_cuts(self, model: Any) -> Dict[ConstraintName, Any]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@@ -107,10 +107,10 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: ConstraintName,
|
||||
violation_name: Any,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
self.instance.enforce_user_cut(solver, model, violation_name)
|
||||
|
||||
@overrides
|
||||
def load(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user