diff --git a/miplearn/solvers.py b/miplearn/solvers.py index 5c42699..f262cd2 100644 --- a/miplearn/solvers.py +++ b/miplearn/solvers.py @@ -29,7 +29,7 @@ class LearningSolver: gap_limit=None, internal_solver_factory=_gurobi_factory, components=None, - mode=None): + mode="exact"): self.is_persistent = None self.internal_solver = None self.components = components @@ -38,19 +38,18 @@ class LearningSolver: self.time_limit = time_limit self.gap_limit = gap_limit self.tee = False + self.mode = mode if self.components is not None: assert isinstance(self.components, dict) else: self.components = { "warm-start": WarmStartComponent(), - #"branch-priority": BranchPriorityComponent(), } - if mode is not None: - assert mode in ["exact", "heuristic"] - for component in self.components.values(): - component.mode = mode + assert self.mode in ["exact", "heuristic"] + for component in self.components.values(): + component.mode = self.mode def _create_solver(self): self.internal_solver = self.internal_solver_factory() @@ -82,7 +81,10 @@ class LearningSolver: else: solve_results = self.internal_solver.solve(model, tee=tee, warmstart=is_warm_start_available) - solve_results["Solver"][0]["Nodes"] = self.internal_solver._solver_model.getAttr("NodeCount") + if hasattr(self.internal_solver, "_solver_model"): + solve_results["Solver"][0]["Nodes"] = self.internal_solver._solver_model.getAttr("NodeCount") + else: + solve_results["Solver"][0]["Nodes"] = 1 for component in self.components.values(): component.after_solve(self, instance, model)