Make LearningSolver picklable

pull/3/head
Alinson S. Xavier 6 years ago
parent 43225681cb
commit da32ff2f35

@ -6,10 +6,8 @@ from . import ObjectiveValueComponent, PrimalSolutionComponent, LazyConstraintsC
import pyomo.environ as pe import pyomo.environ as pe
from pyomo.core import Var from pyomo.core import Var
from copy import deepcopy from copy import deepcopy
import pickle
from scipy.stats import randint from scipy.stats import randint
from p_tqdm import p_map from p_tqdm import p_map
import numpy as np
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,10 +34,14 @@ def _parallel_solve(instance_idx):
class InternalSolver: class InternalSolver:
def __init__(self): def __init__(self):
self.all_vars = None
self.instance = None
self.is_warm_start_available = False self.is_warm_start_available = False
self.model = None self.model = None
self.sense = None
self.solver = None
self.var_name_to_var = {} self.var_name_to_var = {}
def solve_lp(self, tee=False): def solve_lp(self, tee=False):
self.solver.set_instance(self.model) self.solver.set_instance(self.model)
@ -98,7 +100,7 @@ class InternalSolver:
(count_fixed, count_total)) (count_fixed, count_total))
def set_model(self, model): def set_model(self, model):
from pyomo.core.kernel.objective import minimize, maximize from pyomo.core.kernel.objective import minimize
self.model = model self.model = model
self.solver.set_instance(model) self.solver.set_instance(model)
if self.solver._objective.sense == minimize: if self.solver._objective.sense == minimize:
@ -176,6 +178,7 @@ class GurobiSolver(InternalSolver):
def solve(self, tee=False): def solve(self, tee=False):
from gurobipy import GRB from gurobipy import GRB
def cb(cb_model, cb_opt, cb_where): def cb(cb_model, cb_opt, cb_where):
if cb_where == GRB.Callback.MIPSOL: if cb_where == GRB.Callback.MIPSOL:
cb_opt.cbGetSolution(self.all_vars) cb_opt.cbGetSolution(self.all_vars)
@ -186,6 +189,7 @@ class GurobiSolver(InternalSolver):
for v in violations: for v in violations:
cut = self.instance.build_lazy_constraint(cb_model, v) cut = self.instance.build_lazy_constraint(cb_model, v)
cb_opt.cbLazy(cut) cb_opt.cbLazy(cut)
if hasattr(self.instance, "find_violations"): if hasattr(self.instance, "find_violations"):
self.solver.options["LazyConstraints"] = 1 self.solver.options["LazyConstraints"] = 1
self.solver.set_callback(cb) self.solver.set_callback(cb)
@ -204,7 +208,6 @@ class GurobiSolver(InternalSolver):
class CPLEXSolver(InternalSolver): class CPLEXSolver(InternalSolver):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
import cplex
self.solver = pe.SolverFactory('cplex_persistent') self.solver = pe.SolverFactory('cplex_persistent')
self.solver.options["randomseed"] = randint(low=0, high=1000).rvs() self.solver.options["randomseed"] = randint(low=0, high=1000).rvs()
@ -242,10 +245,8 @@ class LearningSolver:
mode="exact", mode="exact",
solver="gurobi", solver="gurobi",
threads=4, threads=4,
time_limit=None, time_limit=None):
):
self.is_persistent = None
self.components = {} self.components = {}
self.mode = mode self.mode = mode
self.internal_solver = None self.internal_solver = None
@ -286,11 +287,11 @@ class LearningSolver:
instance, instance,
model=None, model=None,
tee=False, tee=False,
relaxation_only=False, relaxation_only=False):
):
if model is None: if model is None:
model = instance.to_model() model = instance.to_model()
self.tee = tee self.tee = tee
self.internal_solver = self._create_internal_solver() self.internal_solver = self._create_internal_solver()
self.internal_solver.set_model(model) self.internal_solver.set_model(model)
@ -324,9 +325,7 @@ class LearningSolver:
def parallel_solve(self, def parallel_solve(self,
instances, instances,
n_jobs=4, n_jobs=4,
label="Solve", label="Solve"):
collect_training_data=True,
):
self.internal_solver = None self.internal_solver = None
SOLVER[0] = self SOLVER[0] = self
@ -356,3 +355,7 @@ class LearningSolver:
def add(self, component): def add(self, component):
name = component.__class__.__name__ name = component.__class__.__name__
self.components[name] = component self.components[name] = component
def __getstate__(self):
self.internal_solver = None
return self.__dict__

@ -42,8 +42,9 @@ def test_solver():
solver.fit([instance]) solver.fit([instance])
solver.solve(instance) solver.solve(instance)
# with tempfile.TemporaryFile() as file: # Assert solver is picklable
# pickle.dump(solver, file) with tempfile.TemporaryFile() as file:
pickle.dump(solver, file)
def test_parallel_solve(): def test_parallel_solve():

Loading…
Cancel
Save