diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 5f8183a..7864995 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -101,11 +101,16 @@ class LearningSolver: solve_lp: bool = True, simulate_perfect: bool = False, extractor: Optional[FeaturesExtractor] = None, + extract_lhs: bool = True, + extract_sa: bool = True, ) -> None: if solver is None: solver = GurobiPyomoSolver() if extractor is None: - extractor = FeaturesExtractor() + extractor = FeaturesExtractor( + with_sa=extract_sa, + with_lhs=extract_lhs, + ) assert isinstance(solver, InternalSolver) self.components: Dict[str, Component] = {} self.internal_solver: Optional[InternalSolver] = None