Use np.ndarray for constraint methods in Instance

This commit is contained in:
2021-08-09 20:11:37 -05:00
parent 895cb962b6
commit e852d5cdca
16 changed files with 532 additions and 429 deletions

View File

@@ -15,7 +15,7 @@ from miplearn.components.component import Component
from miplearn.components.dynamic_common import DynamicConstraintsComponent
from miplearn.features.sample import Sample
from miplearn.instance.base import Instance
from miplearn.types import LearningSolveStats
from miplearn.types import LearningSolveStats, ConstraintName, ConstraintCategory
logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ class UserCutsComponent(Component):
threshold=threshold,
attr="mip_user_cuts_enforced",
)
self.enforced: Set[str] = set()
self.enforced: Set[ConstraintName] = set()
self.n_added_in_callback = 0
@overrides
@@ -71,7 +71,7 @@ class UserCutsComponent(Component):
for cid in cids:
if cid in self.enforced:
continue
assert isinstance(cid, str)
assert isinstance(cid, ConstraintName)
instance.enforce_user_cut(solver.internal_solver, model, cid)
self.enforced.add(cid)
self.n_added_in_callback += 1
@@ -110,7 +110,7 @@ class UserCutsComponent(Component):
self,
instance: "Instance",
sample: Sample,
) -> List[str]:
) -> List[ConstraintName]:
return self.dynamic.sample_predict(instance, sample)
@overrides
@@ -120,8 +120,8 @@ class UserCutsComponent(Component):
@overrides
def fit_xy(
self,
x: Dict[str, np.ndarray],
y: Dict[str, np.ndarray],
x: Dict[ConstraintCategory, np.ndarray],
y: Dict[ConstraintCategory, np.ndarray],
) -> None:
self.dynamic.fit_xy(x, y)
@@ -130,5 +130,5 @@ class UserCutsComponent(Component):
self,
instance: "Instance",
sample: Sample,
) -> Dict[str, Dict[str, float]]:
) -> Dict[ConstraintCategory, Dict[str, float]]:
return self.dynamic.sample_evaluate(instance, sample)