diff --git a/miplearn/solvers.py b/miplearn/solvers.py index 9e2bcb5..e28089b 100644 --- a/miplearn/solvers.py +++ b/miplearn/solvers.py @@ -57,7 +57,7 @@ class LearningSolver: var[index].value = 0 # Solve MILP - self.parent_solver.solve(model, tee=tee, warmstart=True) + self._solve(model, tee=tee) # Update y_train for category in var_split.keys(): @@ -78,5 +78,9 @@ class LearningSolver: self.ws_predictors[category] = WarmStartPredictor() self.ws_predictors[category].fit(x_train, y_train) - def _solve(self, tee): - self.parent_solver.solve(tee=tee) \ No newline at end of file + def _solve(self, model, tee=False): + if hasattr(self.parent_solver, "set_instance"): + self.parent_solver.set_instance(model) + self.parent_solver.solve(tee=tee, warmstart=True) + else: + self.parent_solver.solve(model, tee=tee, warmstart=True)