Use np.ndarray for constraint methods in Instance

This commit is contained in:
2021-08-09 20:11:37 -05:00
parent 895cb962b6
commit e852d5cdca
16 changed files with 532 additions and 429 deletions

View File

@@ -9,6 +9,7 @@ from typing import Any, List, TYPE_CHECKING, Dict
import numpy as np
from miplearn.features.sample import Sample, MemorySample
from miplearn.types import ConstraintName, ConstraintCategory
logger = logging.getLogger(__name__)
@@ -97,26 +98,23 @@ class Instance(ABC):
"""
return names
def get_constraint_features(self) -> Dict[str, List[float]]:
return {}
def get_constraint_features(self, names: np.ndarray) -> np.ndarray:
return np.zeros((len(names), 1))
def get_constraint_categories(self) -> Dict[str, str]:
return {}
def has_static_lazy_constraints(self) -> bool:
return False
def get_constraint_categories(self, names: np.ndarray) -> np.ndarray:
return names
def has_dynamic_lazy_constraints(self) -> bool:
return False
def is_constraint_lazy(self, cid: str) -> bool:
return False
def are_constraints_lazy(self, names: np.ndarray) -> np.ndarray:
return np.zeros(len(names), dtype=bool)
def find_violated_lazy_constraints(
self,
solver: "InternalSolver",
model: Any,
) -> List[str]:
) -> List[ConstraintName]:
"""
Returns lazy constraint violations found for the current solution.
@@ -142,7 +140,7 @@ class Instance(ABC):
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> None:
"""
Adds constraints to the model to ensure that the given violation is fixed.
@@ -168,14 +166,14 @@ class Instance(ABC):
def has_user_cuts(self) -> bool:
return False
def find_violated_user_cuts(self, model: Any) -> List[str]:
def find_violated_user_cuts(self, model: Any) -> List[ConstraintName]:
return []
def enforce_user_cut(
self,
solver: "InternalSolver",
model: Any,
violation: str,
violation: ConstraintName,
) -> Any:
return None