Use np.ndarray for constraint methods in Instance

This commit is contained in:
2021-08-09 20:11:37 -05:00
parent 895cb962b6
commit e852d5cdca
16 changed files with 532 additions and 429 deletions

View File

@@ -9,6 +9,7 @@ from typing import Any, List, TYPE_CHECKING, Dict
import numpy as np
from miplearn.features.sample import Sample, MemorySample
from miplearn.types import ConstraintName, ConstraintCategory
logger = logging.getLogger(__name__)
@@ -97,26 +98,23 @@ class Instance(ABC):
"""
return names
def get_constraint_features(self) -> Dict[str, List[float]]:
return {}
def get_constraint_features(self, names: np.ndarray) -> np.ndarray:
return np.zeros((len(names), 1))
def get_constraint_categories(self) -> Dict[str, str]:
return {}
def has_static_lazy_constraints(self) -> bool:
return False
def get_constraint_categories(self, names: np.ndarray) -> np.ndarray:
return names
def has_dynamic_lazy_constraints(self) -> bool:
return False
def is_constraint_lazy(self, cid: str) -> bool:
return False
def are_constraints_lazy(self, names: np.ndarray) -> np.ndarray:
return np.zeros(len(names), dtype=bool)
def find_violated_lazy_constraints(
self,
solver: "InternalSolver",
model: Any,
) -> List[str]:
) -> List[ConstraintName]:
"""
Returns lazy constraint violations found for the current solution.
@@ -142,7 +140,7 @@ class Instance(ABC):
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> None:
"""
Adds constraints to the model to ensure that the given violation is fixed.
@@ -168,14 +166,14 @@ class Instance(ABC):
def has_user_cuts(self) -> bool:
return False
def find_violated_user_cuts(self, model: Any) -> List[str]:
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
return []
def enforce_user_cut(
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> Any:
return None

View File

@@ -11,6 +11,7 @@ from overrides import overrides
from miplearn.features.sample import Hdf5Sample, Sample
from miplearn.instance.base import Instance
from miplearn.types import ConstraintName, ConstraintCategory
if TYPE_CHECKING:
from miplearn.solvers.learning import InternalSolver
@@ -46,19 +47,14 @@ class FileInstance(Instance):
return self.instance.get_variable_categories(names)
@overrides
def get_constraint_features(self) -> Dict[str, List[float]]:
def get_constraint_features(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_constraint_features()
return self.instance.get_constraint_features(names)
@overrides
def get_constraint_categories(self) -> Dict[str, str]:
def get_constraint_categories(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_constraint_categories()
@overrides
def has_static_lazy_constraints(self) -> bool:
assert self.instance is not None
return self.instance.has_static_lazy_constraints()
return self.instance.get_constraint_categories(names)
@overrides
def has_dynamic_lazy_constraints(self) -> bool:
@@ -66,16 +62,16 @@ class FileInstance(Instance):
return self.instance.has_dynamic_lazy_constraints()
@overrides
def is_constraint_lazy(self, cid: str) -> bool:
def are_constraints_lazy(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.is_constraint_lazy(cid)
return self.instance.are_constraints_lazy(names)
@overrides
def find_violated_lazy_constraints(
self,
solver: "InternalSolver",
model: Any,
) -> List[str]:
) -> List[ConstraintName]:
assert self.instance is not None
return self.instance.find_violated_lazy_constraints(solver, model)
@@ -84,13 +80,13 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> None:
assert self.instance is not None
self.instance.enforce_lazy_constraint(solver, model, violation)
@overrides
def find_violated_user_cuts(self, model: Any) -> List[str]:
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
assert self.instance is not None
return self.instance.find_violated_user_cuts(model)
@@ -99,7 +95,7 @@ class FileInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> None:
assert self.instance is not None
self.instance.enforce_user_cut(solver, model, violation)

View File

@@ -13,6 +13,7 @@ from overrides import overrides
from miplearn.features.sample import Sample
from miplearn.instance.base import Instance
from miplearn.types import ConstraintName, ConstraintCategory
if TYPE_CHECKING:
from miplearn.solvers.learning import InternalSolver
@@ -58,19 +59,14 @@ class PickleGzInstance(Instance):
return self.instance.get_variable_categories(names)
@overrides
def get_constraint_features(self) -> Dict[str, List[float]]:
def get_constraint_features(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_constraint_features()
return self.instance.get_constraint_features(names)
@overrides
def get_constraint_categories(self) -> Dict[str, str]:
def get_constraint_categories(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.get_constraint_categories()
@overrides
def has_static_lazy_constraints(self) -> bool:
assert self.instance is not None
return self.instance.has_static_lazy_constraints()
return self.instance.get_constraint_categories(names)
@overrides
def has_dynamic_lazy_constraints(self) -> bool:
@@ -78,16 +74,16 @@ class PickleGzInstance(Instance):
return self.instance.has_dynamic_lazy_constraints()
@overrides
def is_constraint_lazy(self, cid: str) -> bool:
def are_constraints_lazy(self, names: np.ndarray) -> np.ndarray:
assert self.instance is not None
return self.instance.is_constraint_lazy(cid)
return self.instance.are_constraints_lazy(names)
@overrides
def find_violated_lazy_constraints(
self,
solver: "InternalSolver",
model: Any,
) -> List[str]:
) -> List[ConstraintName]:
assert self.instance is not None
return self.instance.find_violated_lazy_constraints(solver, model)
@@ -96,13 +92,13 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> None:
assert self.instance is not None
self.instance.enforce_lazy_constraint(solver, model, violation)
@overrides
def find_violated_user_cuts(self, model: Any) -> List[str]:
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
assert self.instance is not None
return self.instance.find_violated_user_cuts(model)
@@ -111,7 +107,7 @@ class PickleGzInstance(Instance):
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> None:
assert self.instance is not None
self.instance.enforce_user_cut(solver, model, violation)