mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Use np.ndarray for constraint methods in Instance
This commit is contained in:
@@ -15,7 +15,7 @@ from miplearn.components.component import Component
|
||||
from miplearn.components.dynamic_common import DynamicConstraintsComponent
|
||||
from miplearn.features.sample import Sample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import LearningSolveStats
|
||||
from miplearn.types import LearningSolveStats, ConstraintName, ConstraintCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +34,7 @@ class UserCutsComponent(Component):
|
||||
threshold=threshold,
|
||||
attr="mip_user_cuts_enforced",
|
||||
)
|
||||
self.enforced: Set[str] = set()
|
||||
self.enforced: Set[ConstraintName] = set()
|
||||
self.n_added_in_callback = 0
|
||||
|
||||
@overrides
|
||||
@@ -71,7 +71,7 @@ class UserCutsComponent(Component):
|
||||
for cid in cids:
|
||||
if cid in self.enforced:
|
||||
continue
|
||||
assert isinstance(cid, str)
|
||||
assert isinstance(cid, ConstraintName)
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
self.enforced.add(cid)
|
||||
self.n_added_in_callback += 1
|
||||
@@ -110,7 +110,7 @@ class UserCutsComponent(Component):
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: Sample,
|
||||
) -> List[str]:
|
||||
) -> List[ConstraintName]:
|
||||
return self.dynamic.sample_predict(instance, sample)
|
||||
|
||||
@overrides
|
||||
@@ -120,8 +120,8 @@ class UserCutsComponent(Component):
|
||||
@overrides
|
||||
def fit_xy(
|
||||
self,
|
||||
x: Dict[str, np.ndarray],
|
||||
y: Dict[str, np.ndarray],
|
||||
x: Dict[ConstraintCategory, np.ndarray],
|
||||
y: Dict[ConstraintCategory, np.ndarray],
|
||||
) -> None:
|
||||
self.dynamic.fit_xy(x, y)
|
||||
|
||||
@@ -130,5 +130,5 @@ class UserCutsComponent(Component):
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: Sample,
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
) -> Dict[ConstraintCategory, Dict[str, float]]:
|
||||
return self.dynamic.sample_evaluate(instance, sample)
|
||||
|
||||
Reference in New Issue
Block a user