mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
LearningSolver: add more constructor options; perform fit in parallel
This commit is contained in:
@@ -64,6 +64,7 @@ def train():
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
solver.parallel_solve(train_instances, n_jobs=10)
|
solver.parallel_solve(train_instances, n_jobs=10)
|
||||||
|
solver.fit(n_jobs=10)
|
||||||
solver.save_state("%s/training_data.bin" % basepath)
|
solver.save_state("%s/training_data.bin" % basepath)
|
||||||
save(train_instances, "%s/train_instances.bin" % basepath)
|
save(train_instances, "%s/train_instances.bin" % basepath)
|
||||||
save(test_instances, "%s/test_instances.bin" % basepath)
|
save(test_instances, "%s/test_instances.bin" % basepath)
|
||||||
@@ -103,7 +104,6 @@ def test_ml():
|
|||||||
test_instances = load("%s/test_instances.bin" % basepath)
|
test_instances = load("%s/test_instances.bin" % basepath)
|
||||||
benchmark = BenchmarkRunner(solvers)
|
benchmark = BenchmarkRunner(solvers)
|
||||||
benchmark.load_state("%s/training_data.bin" % basepath)
|
benchmark.load_state("%s/training_data.bin" % basepath)
|
||||||
benchmark.fit()
|
|
||||||
benchmark.load_results("%s/benchmark_baseline.csv" % basepath)
|
benchmark.load_results("%s/benchmark_baseline.csv" % basepath)
|
||||||
benchmark.parallel_solve(test_instances, n_jobs=10)
|
benchmark.parallel_solve(test_instances, n_jobs=10)
|
||||||
benchmark.save_results("%s/benchmark_ml.csv" % basepath)
|
benchmark.save_results("%s/benchmark_ml.csv" % basepath)
|
||||||
|
|||||||
@@ -9,11 +9,8 @@ import pyomo.environ as pe
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import pickle
|
import pickle
|
||||||
from tqdm import tqdm
|
|
||||||
from joblib import Parallel, delayed
|
|
||||||
from scipy.stats import randint
|
from scipy.stats import randint
|
||||||
import multiprocessing
|
from p_tqdm import p_map
|
||||||
|
|
||||||
|
|
||||||
def _gurobi_factory():
|
def _gurobi_factory():
|
||||||
solver = pe.SolverFactory('gurobi_persistent')
|
solver = pe.SolverFactory('gurobi_persistent')
|
||||||
@@ -29,7 +26,9 @@ class LearningSolver:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
threads=4,
|
threads=None,
|
||||||
|
time_limit=None,
|
||||||
|
gap_limit=None,
|
||||||
internal_solver_factory=_gurobi_factory,
|
internal_solver_factory=_gurobi_factory,
|
||||||
components=None,
|
components=None,
|
||||||
mode=None):
|
mode=None):
|
||||||
@@ -37,6 +36,9 @@ class LearningSolver:
|
|||||||
self.internal_solver = None
|
self.internal_solver = None
|
||||||
self.components = components
|
self.components = components
|
||||||
self.internal_solver_factory = internal_solver_factory
|
self.internal_solver_factory = internal_solver_factory
|
||||||
|
self.threads = threads
|
||||||
|
self.time_limit = time_limit
|
||||||
|
self.gap_limit = gap_limit
|
||||||
|
|
||||||
if self.components is not None:
|
if self.components is not None:
|
||||||
assert isinstance(self.components, dict)
|
assert isinstance(self.components, dict)
|
||||||
@@ -54,6 +56,12 @@ class LearningSolver:
|
|||||||
def _create_solver(self):
|
def _create_solver(self):
|
||||||
self.internal_solver = self.internal_solver_factory()
|
self.internal_solver = self.internal_solver_factory()
|
||||||
self.is_persistent = hasattr(self.internal_solver, "set_instance")
|
self.is_persistent = hasattr(self.internal_solver, "set_instance")
|
||||||
|
if self.threads is not None:
|
||||||
|
self.internal_solver.options["Threads"] = self.threads
|
||||||
|
if self.time_limit is not None:
|
||||||
|
self.internal_solver.options["TimeLimit"] = self.time_limit
|
||||||
|
if self.gap_limit is not None:
|
||||||
|
self.internal_solver.options["MIPGap"] = self.gap_limit
|
||||||
|
|
||||||
def solve(self, instance, tee=False):
|
def solve(self, instance, tee=False):
|
||||||
model = instance.to_model()
|
model = instance.to_model()
|
||||||
@@ -93,11 +101,7 @@ class LearningSolver:
|
|||||||
solver.components = {}
|
solver.components = {}
|
||||||
return solver, results
|
return solver, results
|
||||||
|
|
||||||
solver_result_pairs = Parallel(n_jobs=n_jobs)(
|
solver_result_pairs = p_map(_process, instances, num_cpus=n_jobs, desc=label)
|
||||||
delayed(_process)(instance)
|
|
||||||
for instance in tqdm(instances, desc=label, ncols=80)
|
|
||||||
)
|
|
||||||
|
|
||||||
subsolvers = [p[0] for p in solver_result_pairs]
|
subsolvers = [p[0] for p in solver_result_pairs]
|
||||||
results = [p[1] for p in solver_result_pairs]
|
results = [p[1] for p in solver_result_pairs]
|
||||||
|
|
||||||
@@ -109,9 +113,9 @@ class LearningSolver:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def fit(self):
|
def fit(self, n_jobs=1):
|
||||||
for component in self.components.values():
|
for component in self.components.values():
|
||||||
component.fit(self)
|
component.fit(self, n_jobs=n_jobs)
|
||||||
|
|
||||||
def save_state(self, filename):
|
def save_state(self, filename):
|
||||||
with open(filename, "wb") as file:
|
with open(filename, "wb") as file:
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def test_parallel_solve():
|
|||||||
def test_solver_random_branch_priority():
|
def test_solver_random_branch_priority():
|
||||||
instance = _get_instance()
|
instance = _get_instance()
|
||||||
components = {
|
components = {
|
||||||
"warm-start": BranchPriorityComponent(initial_priority=np.array([1., 2., 3., 4.])),
|
"warm-start": BranchPriorityComponent(),
|
||||||
}
|
}
|
||||||
solver = LearningSolver(components=components)
|
solver = LearningSolver(components=components)
|
||||||
solver.solve(instance)
|
solver.solve(instance)
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
docopt
|
docopt
|
||||||
|
matplotlib
|
||||||
mkdocs
|
mkdocs
|
||||||
mkdocs-cinder
|
mkdocs-cinder
|
||||||
networkx
|
networkx
|
||||||
numpy
|
numpy
|
||||||
|
p_tqdm
|
||||||
pandas
|
pandas
|
||||||
pyomo
|
pyomo
|
||||||
pytest
|
pytest
|
||||||
pytest-watch
|
pytest-watch
|
||||||
python-markdown-math
|
python-markdown-math
|
||||||
|
seaborn
|
||||||
sklearn
|
sklearn
|
||||||
tqdm
|
tqdm
|
||||||
|
|||||||
Reference in New Issue
Block a user