mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Redesign InternalSolver constraint methods
This commit is contained in:
@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
|
||||
|
||||
# noinspection PyMethodMayBeStatic
|
||||
class Instance(ABC, EnforceOverrides):
|
||||
"""
|
||||
@@ -170,7 +171,12 @@ class Instance(ABC, EnforceOverrides):
|
||||
def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
|
||||
return []
|
||||
|
||||
def build_user_cut(self, model: Any, violation: Hashable) -> Any:
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: Hashable,
|
||||
) -> Any:
|
||||
return None
|
||||
|
||||
def load(self) -> None:
|
||||
|
||||
@@ -106,9 +106,14 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@overrides
|
||||
def build_user_cut(self, model: Any, violation: Hashable) -> None:
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: Hashable,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.build_user_cut(model, violation)
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
|
||||
@overrides
|
||||
def load(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user