mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Update UserCutsComponent
This commit is contained in:
@@ -41,7 +41,7 @@ class DynamicLazyConstraintsComponent(Component):
|
||||
self.classifiers = self.dynamic.classifiers
|
||||
self.thresholds = self.dynamic.thresholds
|
||||
self.known_cids = self.dynamic.known_cids
|
||||
self.lazy_enforced: Set[str] = set()
|
||||
self.lazy_enforced: Set[Hashable] = set()
|
||||
|
||||
@staticmethod
|
||||
def enforce(
|
||||
|
||||
@@ -38,20 +38,19 @@ class UserCutsComponent(Component):
|
||||
self.n_added_in_callback = 0
|
||||
|
||||
@overrides
|
||||
def before_solve_mip_old(
|
||||
def before_solve_mip(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
instance: "Instance",
|
||||
model: Any,
|
||||
stats: LearningSolveStats,
|
||||
features: Features,
|
||||
training_data: TrainingSample,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
assert solver.internal_solver is not None
|
||||
self.enforced.clear()
|
||||
self.n_added_in_callback = 0
|
||||
logger.info("Predicting violated user cuts...")
|
||||
cids = self.dynamic.sample_predict_old(instance, training_data)
|
||||
cids = self.dynamic.sample_predict(instance, sample)
|
||||
logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids))
|
||||
for cid in cids:
|
||||
instance.enforce_user_cut(solver.internal_solver, model, cid)
|
||||
@@ -80,34 +79,27 @@ class UserCutsComponent(Component):
|
||||
logger.debug(f"Added {len(cids)} violated user cuts")
|
||||
|
||||
@overrides
|
||||
def after_solve_mip_old(
|
||||
def after_solve_mip(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
instance: "Instance",
|
||||
model: Any,
|
||||
stats: LearningSolveStats,
|
||||
features: Features,
|
||||
training_data: TrainingSample,
|
||||
sample: Sample,
|
||||
) -> None:
|
||||
training_data.user_cuts_enforced = set(self.enforced)
|
||||
assert sample.after_mip is not None
|
||||
assert sample.after_mip.extra is not None
|
||||
sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced)
|
||||
stats["UserCuts: Added in callback"] = self.n_added_in_callback
|
||||
if self.n_added_in_callback > 0:
|
||||
logger.info(f"{self.n_added_in_callback} user cuts added in callback")
|
||||
|
||||
# Delegate ML methods to self.dynamic
|
||||
# -------------------------------------------------------------------
|
||||
@overrides
|
||||
def sample_xy_old(
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: TrainingSample,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
return self.dynamic.sample_xy_old(instance, sample)
|
||||
|
||||
@overrides
|
||||
def sample_xy(
|
||||
self,
|
||||
instance: Optional[Instance],
|
||||
instance: "Instance",
|
||||
sample: Sample,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
return self.dynamic.sample_xy(instance, sample)
|
||||
@@ -115,13 +107,13 @@ class UserCutsComponent(Component):
|
||||
def sample_predict(
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: TrainingSample,
|
||||
sample: Sample,
|
||||
) -> List[Hashable]:
|
||||
return self.dynamic.sample_predict(instance, sample)
|
||||
|
||||
@overrides
|
||||
def fit_old(self, training_instances: List["Instance"]) -> None:
|
||||
self.dynamic.fit_old(training_instances)
|
||||
def fit(self, training_instances: List["Instance"]) -> None:
|
||||
self.dynamic.fit(training_instances)
|
||||
|
||||
@overrides
|
||||
def fit_xy(
|
||||
@@ -132,9 +124,9 @@ class UserCutsComponent(Component):
|
||||
self.dynamic.fit_xy(x, y)
|
||||
|
||||
@overrides
|
||||
def sample_evaluate_old(
|
||||
def sample_evaluate(
|
||||
self,
|
||||
instance: "Instance",
|
||||
sample: TrainingSample,
|
||||
sample: Sample,
|
||||
) -> Dict[Hashable, Dict[str, float]]:
|
||||
return self.dynamic.sample_evaluate_old(instance, sample)
|
||||
return self.dynamic.sample_evaluate(instance, sample)
|
||||
|
||||
@@ -80,9 +80,10 @@ def test_usage(
|
||||
solver: LearningSolver,
|
||||
) -> None:
|
||||
stats_before = solver.solve(stab_instance)
|
||||
sample = stab_instance.training_data[0]
|
||||
assert sample.user_cuts_enforced is not None
|
||||
assert len(sample.user_cuts_enforced) > 0
|
||||
sample = stab_instance.samples[0]
|
||||
assert sample.after_mip is not None
|
||||
assert sample.after_mip.extra is not None
|
||||
assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0
|
||||
print(stats_before)
|
||||
assert stats_before["UserCuts: Added ahead-of-time"] == 0
|
||||
assert stats_before["UserCuts: Added in callback"] > 0
|
||||
|
||||
@@ -67,6 +67,8 @@ def test_subtour() -> None:
|
||||
instance = TravelingSalesmanInstance(n_cities, distances)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
assert instance.samples[0].after_mip is not None
|
||||
assert instance.samples[0].after_mip.extra is not None
|
||||
lazy_enforced = instance.samples[0].after_mip.extra["lazy_enforced"]
|
||||
assert lazy_enforced is not None
|
||||
assert len(lazy_enforced) > 0
|
||||
|
||||
Reference in New Issue
Block a user