mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
LearningSolver: add more constructor options; perform fit in parallel
This commit is contained in:
@@ -9,11 +9,8 @@ import pyomo.environ as pe
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
from joblib import Parallel, delayed
|
||||
from scipy.stats import randint
|
||||
import multiprocessing
|
||||
|
||||
from p_tqdm import p_map
|
||||
|
||||
def _gurobi_factory():
|
||||
solver = pe.SolverFactory('gurobi_persistent')
|
||||
@@ -29,7 +26,9 @@ class LearningSolver:
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
threads=4,
|
||||
threads=None,
|
||||
time_limit=None,
|
||||
gap_limit=None,
|
||||
internal_solver_factory=_gurobi_factory,
|
||||
components=None,
|
||||
mode=None):
|
||||
@@ -37,6 +36,9 @@ class LearningSolver:
|
||||
self.internal_solver = None
|
||||
self.components = components
|
||||
self.internal_solver_factory = internal_solver_factory
|
||||
self.threads = threads
|
||||
self.time_limit = time_limit
|
||||
self.gap_limit = gap_limit
|
||||
|
||||
if self.components is not None:
|
||||
assert isinstance(self.components, dict)
|
||||
@@ -54,6 +56,12 @@ class LearningSolver:
|
||||
def _create_solver(self):
|
||||
self.internal_solver = self.internal_solver_factory()
|
||||
self.is_persistent = hasattr(self.internal_solver, "set_instance")
|
||||
if self.threads is not None:
|
||||
self.internal_solver.options["Threads"] = self.threads
|
||||
if self.time_limit is not None:
|
||||
self.internal_solver.options["TimeLimit"] = self.time_limit
|
||||
if self.gap_limit is not None:
|
||||
self.internal_solver.options["MIPGap"] = self.gap_limit
|
||||
|
||||
def solve(self, instance, tee=False):
|
||||
model = instance.to_model()
|
||||
@@ -93,11 +101,7 @@ class LearningSolver:
|
||||
solver.components = {}
|
||||
return solver, results
|
||||
|
||||
solver_result_pairs = Parallel(n_jobs=n_jobs)(
|
||||
delayed(_process)(instance)
|
||||
for instance in tqdm(instances, desc=label, ncols=80)
|
||||
)
|
||||
|
||||
solver_result_pairs = p_map(_process, instances, num_cpus=n_jobs, desc=label)
|
||||
subsolvers = [p[0] for p in solver_result_pairs]
|
||||
results = [p[1] for p in solver_result_pairs]
|
||||
|
||||
@@ -109,9 +113,9 @@ class LearningSolver:
|
||||
|
||||
return results
|
||||
|
||||
def fit(self):
|
||||
def fit(self, n_jobs=1):
|
||||
for component in self.components.values():
|
||||
component.fit(self)
|
||||
component.fit(self, n_jobs=n_jobs)
|
||||
|
||||
def save_state(self, filename):
|
||||
with open(filename, "wb") as file:
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_parallel_solve():
|
||||
def test_solver_random_branch_priority():
|
||||
instance = _get_instance()
|
||||
components = {
|
||||
"warm-start": BranchPriorityComponent(initial_priority=np.array([1., 2., 3., 4.])),
|
||||
"warm-start": BranchPriorityComponent(),
|
||||
}
|
||||
solver = LearningSolver(components=components)
|
||||
solver.solve(instance)
|
||||
|
||||
Reference in New Issue
Block a user