diff --git a/src/python/miplearn/components/primal.py b/src/python/miplearn/components/primal.py index bc3b9f4..7bc8d47 100644 --- a/src/python/miplearn/components/primal.py +++ b/src/python/miplearn/components/primal.py @@ -130,6 +130,8 @@ class PrimalSolutionComponent(Component): def before_solve(self, solver, instance, model): solution = self.predict(instance) + if solution is None: + return if self.mode == "heuristic": solver.internal_solver.fix(solution) else: @@ -185,6 +187,7 @@ class PrimalSolutionComponent(Component): x_test = VariableFeaturesExtractor().extract([instance]) solution = {} var_split = Extractor.split_variables(instance) + all_none = True for category in var_split.keys(): for (i, (var, index)) in enumerate(var_split[category]): if var not in solution.keys(): @@ -198,4 +201,8 @@ class PrimalSolutionComponent(Component): (var, index, ws[i, 1], self.thresholds[category, label])) if ws[i, 1] >= self.thresholds[category, label]: solution[var][index] = label + if all_none: + all_none = False + if all_none: + return None return solution diff --git a/src/python/miplearn/solvers.py b/src/python/miplearn/solvers.py index 5ad3ef6..bee0734 100644 --- a/src/python/miplearn/solvers.py +++ b/src/python/miplearn/solvers.py @@ -43,8 +43,6 @@ class InternalSolver: self.var_name_to_var = {} def solve_lp(self, tee=False): - self.solver.set_instance(self.model) - # Relax domain from pyomo.core.base.set_types import Reals, Binary original_domains = [] @@ -274,6 +272,8 @@ class LearningSolver: solver = CPLEXSolver() elif self.internal_solver_factory == "gurobi": solver = GurobiSolver() + elif issubclass(self.internal_solver_factory, InternalSolver): + solver = self.internal_solver_factory() else: raise Exception("solver %s not supported" % self.internal_solver_factory) solver.set_threads(self.threads)