mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 18:38:51 -06:00
Update BranchPriorityComponent; activate it during benchmark
This commit is contained in:
@@ -16,12 +16,12 @@ Options:
|
||||
"""
|
||||
from docopt import docopt
|
||||
import importlib, pathlib
|
||||
from miplearn import (LearningSolver, BenchmarkRunner)
|
||||
from miplearn import LearningSolver, BenchmarkRunner, BranchPriorityComponent
|
||||
from numpy import median
|
||||
import pyomo.environ as pe
|
||||
import pickle
|
||||
import logging
|
||||
import sys
|
||||
import multiprocessing
|
||||
|
||||
logging.basicConfig(format='%(asctime)s %(levelname).1s %(name)s: %(message)12s',
|
||||
datefmt='%H:%M:%S',
|
||||
@@ -32,9 +32,12 @@ logging.getLogger('pyomo.core').setLevel(logging.ERROR)
|
||||
logging.getLogger('miplearn').setLevel(logging.INFO)
|
||||
logger = logging.getLogger("benchmark")
|
||||
|
||||
n_jobs = 10
|
||||
test_time_limit = 3600
|
||||
train_time_limit = 900
|
||||
n_jobs = multiprocessing.cpu_count() // 4
|
||||
logger.info("Running %d jobs in parallel" % n_jobs)
|
||||
|
||||
train_time_limit = 3600
|
||||
test_time_limit = 900
|
||||
test_node_limit = 1_000_000
|
||||
internal_solver = "gurobi"
|
||||
|
||||
args = docopt(__doc__)
|
||||
@@ -46,41 +49,43 @@ def save(obj, filename):
|
||||
logger.info("Writing %s..." % filename)
|
||||
with open(filename, "wb") as file:
|
||||
pickle.dump(obj, file)
|
||||
|
||||
|
||||
|
||||
|
||||
def load(filename):
|
||||
import pickle
|
||||
with open(filename, "rb") as file:
|
||||
return pickle.load(file)
|
||||
|
||||
|
||||
return pickle.load(file)
|
||||
|
||||
|
||||
def train():
|
||||
problem_name, challenge_name = args["<challenge>"].split("/")
|
||||
pkg = importlib.import_module("miplearn.problems.%s" % problem_name)
|
||||
challenge = getattr(pkg, challenge_name)()
|
||||
train_instances = challenge.training_instances
|
||||
test_instances = challenge.test_instances
|
||||
test_instances = challenge.test_instances
|
||||
solver = LearningSolver(time_limit=train_time_limit,
|
||||
solver=internal_solver,
|
||||
components={})
|
||||
solver.parallel_solve(train_instances, n_jobs=n_jobs)
|
||||
solver=internal_solver)
|
||||
solver.add(BranchPriorityComponent())
|
||||
solver.parallel_solve(train_instances[:1], n_jobs=n_jobs)
|
||||
solver.fit(train_instances[:1])
|
||||
save(train_instances, "%s/train_instances.bin" % basepath)
|
||||
save(test_instances, "%s/test_instances.bin" % basepath)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_baseline():
|
||||
test_instances = load("%s/test_instances.bin" % basepath)
|
||||
solvers = {
|
||||
"baseline": LearningSolver(
|
||||
time_limit=test_time_limit,
|
||||
node_limit=test_node_limit,
|
||||
solver=internal_solver,
|
||||
),
|
||||
}
|
||||
benchmark = BenchmarkRunner(solvers)
|
||||
benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
|
||||
benchmark.save_results("%s/benchmark_baseline.csv" % basepath)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_ml():
|
||||
logger.info("Loading instances...")
|
||||
train_instances = load("%s/train_instances.bin" % basepath)
|
||||
@@ -88,14 +93,17 @@ def test_ml():
|
||||
solvers = {
|
||||
"ml-exact": LearningSolver(
|
||||
time_limit=test_time_limit,
|
||||
node_limit=test_node_limit,
|
||||
solver=internal_solver,
|
||||
),
|
||||
"ml-heuristic": LearningSolver(
|
||||
time_limit=test_time_limit,
|
||||
node_limit=test_node_limit,
|
||||
solver=internal_solver,
|
||||
mode="heuristic",
|
||||
),
|
||||
}
|
||||
solvers["ml-exact"].add(BranchPriorityComponent())
|
||||
benchmark = BenchmarkRunner(solvers)
|
||||
logger.info("Loading results...")
|
||||
benchmark.load_results("%s/benchmark_baseline.csv" % basepath)
|
||||
@@ -105,7 +113,7 @@ def test_ml():
|
||||
benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
|
||||
benchmark.save_results("%s/benchmark_ml.csv" % basepath)
|
||||
|
||||
|
||||
|
||||
def charts():
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
@@ -115,7 +123,7 @@ def charts():
|
||||
benchmark.load_results("%s/benchmark_ml.csv" % basepath)
|
||||
results = benchmark.raw_results()
|
||||
results["Gap (%)"] = results["Gap"] * 100.0
|
||||
|
||||
|
||||
sense = results.loc[0, "Sense"]
|
||||
if sense == "min":
|
||||
primal_column = "Relative Upper Bound"
|
||||
@@ -125,17 +133,17 @@ def charts():
|
||||
primal_column = "Relative Lower Bound"
|
||||
obj_column = "Lower Bound"
|
||||
predicted_obj_column = "Predicted LB"
|
||||
|
||||
palette={
|
||||
"baseline": "#9b59b6",
|
||||
|
||||
palette = {
|
||||
"baseline": "#9b59b6",
|
||||
"ml-exact": "#3498db",
|
||||
"ml-heuristic": "#95a5a6"
|
||||
}
|
||||
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1,
|
||||
ncols=4,
|
||||
figsize=(12,4),
|
||||
figsize=(12, 4),
|
||||
gridspec_kw={'width_ratios': [2, 1, 1, 2]},
|
||||
)
|
||||
)
|
||||
sns.stripplot(x="Solver",
|
||||
y="Wallclock Time",
|
||||
data=results,
|
||||
@@ -143,7 +151,7 @@ def charts():
|
||||
jitter=0.25,
|
||||
palette=palette,
|
||||
size=4.0,
|
||||
);
|
||||
);
|
||||
sns.barplot(x="Solver",
|
||||
y="Wallclock Time",
|
||||
data=results,
|
||||
@@ -152,7 +160,7 @@ def charts():
|
||||
alpha=0.4,
|
||||
palette=palette,
|
||||
estimator=median,
|
||||
);
|
||||
);
|
||||
ax1.set(ylabel='Wallclock Time (s)')
|
||||
ax2.set_ylim(-0.5, 5.5)
|
||||
sns.stripplot(x="Solver",
|
||||
@@ -162,34 +170,35 @@ def charts():
|
||||
ax=ax2,
|
||||
palette=palette,
|
||||
size=4.0,
|
||||
);
|
||||
ax3.set_ylim(0.95,1.05)
|
||||
);
|
||||
# ax3.set_ylim(0.95,1.05)
|
||||
sns.stripplot(x="Solver",
|
||||
y=primal_column,
|
||||
jitter=0.25,
|
||||
data=results[results["Solver"] == "ml-heuristic"],
|
||||
ax=ax3,
|
||||
palette=palette,
|
||||
);
|
||||
|
||||
);
|
||||
|
||||
sns.scatterplot(x=obj_column,
|
||||
y=predicted_obj_column,
|
||||
hue="Solver",
|
||||
data=results[results["Solver"] == "ml-exact"],
|
||||
ax=ax4,
|
||||
palette=palette,
|
||||
);
|
||||
);
|
||||
xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
|
||||
ax4.plot([-1e10, 1e10], [-1e10, 1e10], ls='-', color="#cccccc");
|
||||
ax4.set_xlim(xlim)
|
||||
ax4.set_ylim(ylim)
|
||||
ax4.get_legend().remove()
|
||||
|
||||
|
||||
fig.tight_layout()
|
||||
plt.savefig("%s/performance.png" % basepath,
|
||||
bbox_inches='tight',
|
||||
dpi=150)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if args["train"]:
|
||||
train()
|
||||
|
||||
Reference in New Issue
Block a user