From c26b852c67e3518b5a843596d8d971c89af35c8b Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 13 Apr 2021 09:05:34 -0500 Subject: [PATCH] Update UserCutsComponent --- miplearn/components/dynamic_lazy.py | 2 +- miplearn/components/dynamic_user_cuts.py | 38 +++++++++------------- tests/components/test_dynamic_user_cuts.py | 7 ++-- tests/problems/test_tsp.py | 2 ++ 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/miplearn/components/dynamic_lazy.py b/miplearn/components/dynamic_lazy.py index d676af2..ad2f3fb 100644 --- a/miplearn/components/dynamic_lazy.py +++ b/miplearn/components/dynamic_lazy.py @@ -41,7 +41,7 @@ class DynamicLazyConstraintsComponent(Component): self.classifiers = self.dynamic.classifiers self.thresholds = self.dynamic.thresholds self.known_cids = self.dynamic.known_cids - self.lazy_enforced: Set[str] = set() + self.lazy_enforced: Set[Hashable] = set() @staticmethod def enforce( diff --git a/miplearn/components/dynamic_user_cuts.py b/miplearn/components/dynamic_user_cuts.py index 67e9037..3a24400 100644 --- a/miplearn/components/dynamic_user_cuts.py +++ b/miplearn/components/dynamic_user_cuts.py @@ -38,20 +38,19 @@ class UserCutsComponent(Component): self.n_added_in_callback = 0 @overrides - def before_solve_mip_old( + def before_solve_mip( self, solver: "LearningSolver", instance: "Instance", model: Any, stats: LearningSolveStats, - features: Features, - training_data: TrainingSample, + sample: Sample, ) -> None: assert solver.internal_solver is not None self.enforced.clear() self.n_added_in_callback = 0 logger.info("Predicting violated user cuts...") - cids = self.dynamic.sample_predict_old(instance, training_data) + cids = self.dynamic.sample_predict(instance, sample) logger.info("Enforcing %d user cuts ahead-of-time..." % len(cids)) for cid in cids: instance.enforce_user_cut(solver.internal_solver, model, cid) @@ -80,34 +79,27 @@ class UserCutsComponent(Component): logger.debug(f"Added {len(cids)} violated user cuts") @overrides - def after_solve_mip_old( + def after_solve_mip( self, solver: "LearningSolver", instance: "Instance", model: Any, stats: LearningSolveStats, - features: Features, - training_data: TrainingSample, + sample: Sample, ) -> None: - training_data.user_cuts_enforced = set(self.enforced) + assert sample.after_mip is not None + assert sample.after_mip.extra is not None + sample.after_mip.extra["user_cuts_enforced"] = set(self.enforced) stats["UserCuts: Added in callback"] = self.n_added_in_callback if self.n_added_in_callback > 0: logger.info(f"{self.n_added_in_callback} user cuts added in callback") # Delegate ML methods to self.dynamic # ------------------------------------------------------------------- - @overrides - def sample_xy_old( - self, - instance: "Instance", - sample: TrainingSample, - ) -> Tuple[Dict, Dict]: - return self.dynamic.sample_xy_old(instance, sample) - @overrides def sample_xy( self, - instance: Optional[Instance], + instance: "Instance", sample: Sample, ) -> Tuple[Dict, Dict]: return self.dynamic.sample_xy(instance, sample) @@ -115,13 +107,13 @@ class UserCutsComponent(Component): def sample_predict( self, instance: "Instance", - sample: TrainingSample, + sample: Sample, ) -> List[Hashable]: return self.dynamic.sample_predict(instance, sample) @overrides - def fit_old(self, training_instances: List["Instance"]) -> None: - self.dynamic.fit_old(training_instances) + def fit(self, training_instances: List["Instance"]) -> None: + self.dynamic.fit(training_instances) @overrides def fit_xy( @@ -132,9 +124,9 @@ class UserCutsComponent(Component): self.dynamic.fit_xy(x, y) @overrides - def sample_evaluate_old( + def sample_evaluate( self, instance: "Instance", - sample: TrainingSample, + sample: Sample, ) -> Dict[Hashable, Dict[str, float]]: - return self.dynamic.sample_evaluate_old(instance, sample) + return self.dynamic.sample_evaluate(instance, sample) diff --git a/tests/components/test_dynamic_user_cuts.py b/tests/components/test_dynamic_user_cuts.py index 199efed..45a8988 100644 --- a/tests/components/test_dynamic_user_cuts.py +++ b/tests/components/test_dynamic_user_cuts.py @@ -80,9 +80,10 @@ def test_usage( solver: LearningSolver, ) -> None: stats_before = solver.solve(stab_instance) - sample = stab_instance.training_data[0] - assert sample.user_cuts_enforced is not None - assert len(sample.user_cuts_enforced) > 0 + sample = stab_instance.samples[0] + assert sample.after_mip is not None + assert sample.after_mip.extra is not None + assert len(sample.after_mip.extra["user_cuts_enforced"]) > 0 print(stats_before) assert stats_before["UserCuts: Added ahead-of-time"] == 0 assert stats_before["UserCuts: Added in callback"] > 0 diff --git a/tests/problems/test_tsp.py b/tests/problems/test_tsp.py index 37288e2..4f28997 100644 --- a/tests/problems/test_tsp.py +++ b/tests/problems/test_tsp.py @@ -67,6 +67,8 @@ def test_subtour() -> None: instance = TravelingSalesmanInstance(n_cities, distances) solver = LearningSolver() solver.solve(instance) + assert instance.samples[0].after_mip is not None + assert instance.samples[0].after_mip.extra is not None lazy_enforced = instance.samples[0].after_mip.extra["lazy_enforced"] assert lazy_enforced is not None assert len(lazy_enforced) > 0