mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Use np.ndarray for constraint methods in Instance
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import Any, List, TYPE_CHECKING, Dict
|
||||
import numpy as np
|
||||
|
||||
from miplearn.features.sample import Sample, MemorySample
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -97,26 +98,23 @@ class Instance(ABC):
|
||||
"""
|
||||
return names
|
||||
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
return {}
|
||||
def get_constraint_features(self, names: np.ndarray) -> np.ndarray:
|
||||
return np.zeros((len(names), 1))
|
||||
|
||||
def get_constraint_categories(self) -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
return False
|
||||
def get_constraint_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
return names
|
||||
|
||||
def has_dynamic_lazy_constraints(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_constraint_lazy(self, cid: str) -> bool:
|
||||
return False
|
||||
def are_constraints_lazy(self, names: np.ndarray) -> np.ndarray:
|
||||
return np.zeros(len(names), dtype=bool)
|
||||
|
||||
def find_violated_lazy_constraints(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[str]:
|
||||
) -> List[ConstraintName]:
|
||||
"""
|
||||
Returns lazy constraint violations found for the current solution.
|
||||
|
||||
@@ -142,7 +140,7 @@ class Instance(ABC):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: str,
|
||||
violation: ConstraintName,
|
||||
) -> None:
|
||||
"""
|
||||
Adds constraints to the model to ensure that the given violation is fixed.
|
||||
@@ -168,14 +166,14 @@ class Instance(ABC):
|
||||
def has_user_cuts(self) -> bool:
|
||||
return False
|
||||
|
||||
def find_violated_user_cuts(self, model: Any) -> List[str]:
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
return []
|
||||
|
||||
def enforce_user_cut(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: str,
|
||||
violation: ConstraintName,
|
||||
) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Hdf5Sample, Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -46,19 +47,14 @@ class FileInstance(Instance):
|
||||
return self.instance.get_variable_categories(names)
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
def get_constraint_features(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_features()
|
||||
return self.instance.get_constraint_features(names)
|
||||
|
||||
@overrides
|
||||
def get_constraint_categories(self) -> Dict[str, str]:
|
||||
def get_constraint_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_categories()
|
||||
|
||||
@overrides
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
assert self.instance is not None
|
||||
return self.instance.has_static_lazy_constraints()
|
||||
return self.instance.get_constraint_categories(names)
|
||||
|
||||
@overrides
|
||||
def has_dynamic_lazy_constraints(self) -> bool:
|
||||
@@ -66,16 +62,16 @@ class FileInstance(Instance):
|
||||
return self.instance.has_dynamic_lazy_constraints()
|
||||
|
||||
@overrides
|
||||
def is_constraint_lazy(self, cid: str) -> bool:
|
||||
def are_constraints_lazy(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.is_constraint_lazy(cid)
|
||||
return self.instance.are_constraints_lazy(names)
|
||||
|
||||
@overrides
|
||||
def find_violated_lazy_constraints(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[str]:
|
||||
) -> List[ConstraintName]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_lazy_constraints(solver, model)
|
||||
|
||||
@@ -84,13 +80,13 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: str,
|
||||
violation: ConstraintName,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation)
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[str]:
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@@ -99,7 +95,7 @@ class FileInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: str,
|
||||
violation: ConstraintName,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
|
||||
@@ -13,6 +13,7 @@ from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import ConstraintName, ConstraintCategory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import InternalSolver
|
||||
@@ -58,19 +59,14 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.get_variable_categories(names)
|
||||
|
||||
@overrides
|
||||
def get_constraint_features(self) -> Dict[str, List[float]]:
|
||||
def get_constraint_features(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_features()
|
||||
return self.instance.get_constraint_features(names)
|
||||
|
||||
@overrides
|
||||
def get_constraint_categories(self) -> Dict[str, str]:
|
||||
def get_constraint_categories(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_constraint_categories()
|
||||
|
||||
@overrides
|
||||
def has_static_lazy_constraints(self) -> bool:
|
||||
assert self.instance is not None
|
||||
return self.instance.has_static_lazy_constraints()
|
||||
return self.instance.get_constraint_categories(names)
|
||||
|
||||
@overrides
|
||||
def has_dynamic_lazy_constraints(self) -> bool:
|
||||
@@ -78,16 +74,16 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.has_dynamic_lazy_constraints()
|
||||
|
||||
@overrides
|
||||
def is_constraint_lazy(self, cid: str) -> bool:
|
||||
def are_constraints_lazy(self, names: np.ndarray) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.is_constraint_lazy(cid)
|
||||
return self.instance.are_constraints_lazy(names)
|
||||
|
||||
@overrides
|
||||
def find_violated_lazy_constraints(
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
) -> List[str]:
|
||||
) -> List[ConstraintName]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_lazy_constraints(solver, model)
|
||||
|
||||
@@ -96,13 +92,13 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: str,
|
||||
violation: ConstraintName,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_lazy_constraint(solver, model, violation)
|
||||
|
||||
@overrides
|
||||
def find_violated_user_cuts(self, model: Any) -> List[str]:
|
||||
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
|
||||
assert self.instance is not None
|
||||
return self.instance.find_violated_user_cuts(model)
|
||||
|
||||
@@ -111,7 +107,7 @@ class PickleGzInstance(Instance):
|
||||
self,
|
||||
solver: "InternalSolver",
|
||||
model: Any,
|
||||
violation: str,
|
||||
violation: ConstraintName,
|
||||
) -> None:
|
||||
assert self.instance is not None
|
||||
self.instance.enforce_user_cut(solver, model, violation)
|
||||
|
||||
Reference in New Issue
Block a user