From 39597287a668e24ef2db5e5b9afd182c150cf7a8 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Thu, 15 Apr 2021 09:57:10 -0500 Subject: [PATCH] Make extractor configurable --- miplearn/features.py | 7 +++---- miplearn/solvers/learning.py | 12 +++++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/miplearn/features.py b/miplearn/features.py index 7e141a5..a1c7cba 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -132,23 +132,22 @@ class Sample: class FeaturesExtractor: def __init__( self, - internal_solver: "InternalSolver", with_sa: bool = True, ) -> None: - self.solver = internal_solver self.with_sa = with_sa def extract( self, instance: "Instance", + solver: "InternalSolver", with_static: bool = True, ) -> Features: features = Features() - features.variables = self.solver.get_variables( + features.variables = solver.get_variables( with_static=with_static, with_sa=self.with_sa, ) - features.constraints = self.solver.get_constraints( + features.constraints = solver.get_constraints( with_static=with_static, ) if with_static: diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 9d41b93..93c2350 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -99,9 +99,12 @@ class LearningSolver: use_lazy_cb: bool = False, solve_lp: bool = True, simulate_perfect: bool = False, + extractor: Optional[FeaturesExtractor] = None, ) -> None: if solver is None: solver = GurobiPyomoSolver() + if extractor is None: + extractor = FeaturesExtractor() assert isinstance(solver, InternalSolver) self.components: Dict[str, Component] = {} self.internal_solver: Optional[InternalSolver] = None @@ -111,6 +114,7 @@ class LearningSolver: self.solve_lp: bool = solve_lp self.tee = False self.use_lazy_cb: bool = use_lazy_cb + self.extractor = extractor if components is not None: for comp in components: self._add_component(comp) @@ -156,7 +160,7 @@ class LearningSolver: # Extract features (after-load) # ------------------------------------------------------- logger.info("Extracting features (after-load)...") - features = FeaturesExtractor(self.internal_solver).extract(instance) + features = self.extractor.extract(instance, self.internal_solver) features.extra = {} sample.after_load = features @@ -187,8 +191,9 @@ class LearningSolver: # Extract features (after-lp) # ------------------------------------------------------- logger.info("Extracting features (after-lp)...") - features = FeaturesExtractor(self.internal_solver).extract( + features = self.extractor.extract( instance, + self.internal_solver, with_static=False, ) features.extra = {} @@ -252,8 +257,9 @@ class LearningSolver: # Extract features (after-mip) # ------------------------------------------------------- logger.info("Extracting features (after-mip)...") - features = FeaturesExtractor(self.internal_solver).extract( + features = self.extractor.extract( instance, + self.internal_solver, with_static=False, ) features.mip_solve = mip_stats