mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Remove {get,put}_set and deprecated functions
This commit is contained in:
@@ -56,6 +56,11 @@ class DynamicConstraintsComponent(Component):
|
||||
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
|
||||
known_cids = np.array(self.known_cids, dtype="S")
|
||||
|
||||
enforced_cids = None
|
||||
enforced_cids_np = sample.get_array(self.attr)
|
||||
if enforced_cids_np is not None:
|
||||
enforced_cids = list(enforced_cids_np)
|
||||
|
||||
# Get user-provided constraint features
|
||||
(
|
||||
constr_features,
|
||||
@@ -72,13 +77,11 @@ class DynamicConstraintsComponent(Component):
|
||||
constr_features,
|
||||
]
|
||||
)
|
||||
assert len(known_cids) == constr_features.shape[0]
|
||||
|
||||
categories = np.unique(constr_categories)
|
||||
for c in categories:
|
||||
x[c] = constr_features[constr_categories == c].tolist()
|
||||
cids[c] = known_cids[constr_categories == c].tolist()
|
||||
enforced_cids = np.array(list(sample.get_set(self.attr)), dtype="S")
|
||||
if enforced_cids is not None:
|
||||
tmp = np.isin(cids[c], enforced_cids).reshape(-1, 1)
|
||||
y[c] = np.hstack([~tmp, tmp]).tolist() # type: ignore
|
||||
@@ -99,7 +102,7 @@ class DynamicConstraintsComponent(Component):
|
||||
assert pre is not None
|
||||
known_cids: Set = set()
|
||||
for cids in pre:
|
||||
known_cids |= cids
|
||||
known_cids |= set(list(cids))
|
||||
self.known_cids.clear()
|
||||
self.known_cids.extend(sorted(known_cids))
|
||||
|
||||
@@ -128,7 +131,7 @@ class DynamicConstraintsComponent(Component):
|
||||
|
||||
@overrides
|
||||
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
|
||||
return sample.get_set(self.attr)
|
||||
return sample.get_array(self.attr)
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
@@ -150,7 +153,7 @@ class DynamicConstraintsComponent(Component):
|
||||
instance: Instance,
|
||||
sample: Sample,
|
||||
) -> Dict[str, float]:
|
||||
actual = sample.get_set(self.attr)
|
||||
actual = sample.get_array(self.attr)
|
||||
assert actual is not None
|
||||
pred = set(self.sample_predict(instance, sample))
|
||||
tp, tn, fp, fn = 0, 0, 0, 0
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import pdb
|
||||
from typing import Dict, List, TYPE_CHECKING, Tuple, Any, Optional, Set
|
||||
|
||||
import numpy as np
|
||||
@@ -78,7 +79,10 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
stats: LearningSolveStats,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
sample.put_set("mip_constr_lazy_enforced", set(self.lazy_enforced))
|
||||
sample.put_array(
|
||||
"mip_constr_lazy_enforced",
|
||||
np.array(list(self.lazy_enforced), dtype="S"),
|
||||
)
|
||||
|
||||
@overrides
|
||||
def iteration_cb(
|
||||
|
||||
@@ -87,7 +87,10 @@ class UserCutsComponent(Component):
|
||||
stats: LearningSolveStats,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
sample.put_set("mip_user_cuts_enforced", set(self.enforced))
|
||||
sample.put_array(
|
||||
"mip_user_cuts_enforced",
|
||||
np.array(list(self.enforced), dtype="S"),
|
||||
)
|
||||
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
||||
if self.n_added_in_callback > 0:
|
||||
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
||||
|
||||
@@ -61,7 +61,10 @@ class StaticLazyConstraintsComponent(Component):
|
||||
stats: LearningSolveStats,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
sample.put_set("mip_constr_lazy_enforced", self.enforced_cids)
|
||||
sample.put_array(
|
||||
"mip_constr_lazy_enforced",
|
||||
np.array(list(self.enforced_cids), dtype="S"),
|
||||
)
|
||||
stats["LazyStatic: Restored"] = self.n_restored
|
||||
stats["LazyStatic: Iterations"] = self.n_iterations
|
||||
|
||||
@@ -212,7 +215,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
constr_names = sample.get_array("static_constr_names")
|
||||
constr_categories = sample.get_array("static_constr_categories")
|
||||
constr_lazy = sample.get_array("static_constr_lazy")
|
||||
lazy_enforced = sample.get_set("mip_constr_lazy_enforced")
|
||||
lazy_enforced = sample.get_array("mip_constr_lazy_enforced")
|
||||
if constr_features is None:
|
||||
constr_features = sample.get_array("static_constr_features")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user