generalizing solvers for LearningSolver; ensuring primal does not get set if it does not exist

pull/3/head
bknueven 6 years ago committed by Alinson S Xavier
parent 2a558931d1
commit ff3efe4884

@ -130,6 +130,8 @@ class PrimalSolutionComponent(Component):
def before_solve(self, solver, instance, model):
solution = self.predict(instance)
if solution is None:
return
if self.mode == "heuristic":
solver.internal_solver.fix(solution)
else:
@ -185,6 +187,7 @@ class PrimalSolutionComponent(Component):
x_test = VariableFeaturesExtractor().extract([instance])
solution = {}
var_split = Extractor.split_variables(instance)
all_none = True
for category in var_split.keys():
for (i, (var, index)) in enumerate(var_split[category]):
if var not in solution.keys():
@ -198,4 +201,8 @@ class PrimalSolutionComponent(Component):
(var, index, ws[i, 1], self.thresholds[category, label]))
if ws[i, 1] >= self.thresholds[category, label]:
solution[var][index] = label
if all_none:
all_none = False
if all_none:
return None
return solution

@ -43,8 +43,6 @@ class InternalSolver:
self.var_name_to_var = {}
def solve_lp(self, tee=False):
self.solver.set_instance(self.model)
# Relax domain
from pyomo.core.base.set_types import Reals, Binary
original_domains = []
@ -274,6 +272,8 @@ class LearningSolver:
solver = CPLEXSolver()
elif self.internal_solver_factory == "gurobi":
solver = GurobiSolver()
elif issubclass(self.internal_solver_factory, InternalSolver):
solver = self.internal_solver_factory()
else:
raise Exception("solver %s not supported" % self.internal_solver_factory)
solver.set_threads(self.threads)

Loading…
Cancel
Save