mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
generalizing solvers for LearningSolver; ensuring primal does not get set if it does not exist
This commit is contained in:
@@ -130,6 +130,8 @@ class PrimalSolutionComponent(Component):
|
|||||||
|
|
||||||
def before_solve(self, solver, instance, model):
|
def before_solve(self, solver, instance, model):
|
||||||
solution = self.predict(instance)
|
solution = self.predict(instance)
|
||||||
|
if solution is None:
|
||||||
|
return
|
||||||
if self.mode == "heuristic":
|
if self.mode == "heuristic":
|
||||||
solver.internal_solver.fix(solution)
|
solver.internal_solver.fix(solution)
|
||||||
else:
|
else:
|
||||||
@@ -185,6 +187,7 @@ class PrimalSolutionComponent(Component):
|
|||||||
x_test = VariableFeaturesExtractor().extract([instance])
|
x_test = VariableFeaturesExtractor().extract([instance])
|
||||||
solution = {}
|
solution = {}
|
||||||
var_split = Extractor.split_variables(instance)
|
var_split = Extractor.split_variables(instance)
|
||||||
|
all_none = True
|
||||||
for category in var_split.keys():
|
for category in var_split.keys():
|
||||||
for (i, (var, index)) in enumerate(var_split[category]):
|
for (i, (var, index)) in enumerate(var_split[category]):
|
||||||
if var not in solution.keys():
|
if var not in solution.keys():
|
||||||
@@ -198,4 +201,8 @@ class PrimalSolutionComponent(Component):
|
|||||||
(var, index, ws[i, 1], self.thresholds[category, label]))
|
(var, index, ws[i, 1], self.thresholds[category, label]))
|
||||||
if ws[i, 1] >= self.thresholds[category, label]:
|
if ws[i, 1] >= self.thresholds[category, label]:
|
||||||
solution[var][index] = label
|
solution[var][index] = label
|
||||||
|
if all_none:
|
||||||
|
all_none = False
|
||||||
|
if all_none:
|
||||||
|
return None
|
||||||
return solution
|
return solution
|
||||||
|
|||||||
@@ -43,8 +43,6 @@ class InternalSolver:
|
|||||||
self.var_name_to_var = {}
|
self.var_name_to_var = {}
|
||||||
|
|
||||||
def solve_lp(self, tee=False):
|
def solve_lp(self, tee=False):
|
||||||
self.solver.set_instance(self.model)
|
|
||||||
|
|
||||||
# Relax domain
|
# Relax domain
|
||||||
from pyomo.core.base.set_types import Reals, Binary
|
from pyomo.core.base.set_types import Reals, Binary
|
||||||
original_domains = []
|
original_domains = []
|
||||||
@@ -274,6 +272,8 @@ class LearningSolver:
|
|||||||
solver = CPLEXSolver()
|
solver = CPLEXSolver()
|
||||||
elif self.internal_solver_factory == "gurobi":
|
elif self.internal_solver_factory == "gurobi":
|
||||||
solver = GurobiSolver()
|
solver = GurobiSolver()
|
||||||
|
elif issubclass(self.internal_solver_factory, InternalSolver):
|
||||||
|
solver = self.internal_solver_factory()
|
||||||
else:
|
else:
|
||||||
raise Exception("solver %s not supported" % self.internal_solver_factory)
|
raise Exception("solver %s not supported" % self.internal_solver_factory)
|
||||||
solver.set_threads(self.threads)
|
solver.set_threads(self.threads)
|
||||||
|
|||||||
Reference in New Issue
Block a user