mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make package work with persistent solvers; update README.md
This commit is contained in:
@@ -13,6 +13,10 @@ from joblib import Parallel, delayed
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def _gurobi_factory():
|
||||
solver = pe.SolverFactory('gurobi_persistent')
|
||||
solver.options["threads"] = 4
|
||||
return solver
|
||||
|
||||
class LearningSolver:
|
||||
"""
|
||||
@@ -22,11 +26,11 @@ class LearningSolver:
|
||||
|
||||
def __init__(self,
|
||||
threads=4,
|
||||
parent_solver=pe.SolverFactory('gurobi'),
|
||||
internal_solver_factory=_gurobi_factory,
|
||||
ws_predictor=KnnWarmStartPredictor(),
|
||||
mode="exact"):
|
||||
self.parent_solver = parent_solver
|
||||
self.parent_solver.options["threads"] = threads
|
||||
self.internal_solver_factory = internal_solver_factory
|
||||
self.internal_solver = self.internal_solver_factory()
|
||||
self.mode = mode
|
||||
self.x_train = {}
|
||||
self.y_train = {}
|
||||
@@ -86,8 +90,11 @@ class LearningSolver:
|
||||
return solve_results
|
||||
|
||||
def parallel_solve(self, instances, n_jobs=4, label="Solve"):
|
||||
self.parentSolver = None
|
||||
|
||||
def _process(instance):
|
||||
solver = deepcopy(self)
|
||||
solver = copy(self)
|
||||
solver.internal_solver = solver.internal_solver_factory()
|
||||
results = solver.solve(instance)
|
||||
return {
|
||||
"x_train": solver.x_train,
|
||||
@@ -143,8 +150,8 @@ class LearningSolver:
|
||||
self.ws_predictors = self.ws_predictors
|
||||
|
||||
def _solve(self, model, tee=False):
|
||||
if hasattr(self.parent_solver, "set_instance"):
|
||||
self.parent_solver.set_instance(model)
|
||||
return self.parent_solver.solve(tee=tee, warmstart=True)
|
||||
if hasattr(self.internal_solver, "set_instance"):
|
||||
self.internal_solver.set_instance(model)
|
||||
return self.internal_solver.solve(tee=tee, warmstart=True)
|
||||
else:
|
||||
return self.parent_solver.solve(model, tee=tee, warmstart=True)
|
||||
return self.internal_solver.solve(model, tee=tee, warmstart=True)
|
||||
|
||||
Reference in New Issue
Block a user