mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Benchmark: Avoid loading instances to memory
This commit is contained in:
@@ -13,7 +13,7 @@ Usage:
|
||||
|
||||
Options:
|
||||
-h --help Show this screen
|
||||
--jobs=<n> Number of instances to solve simultaneously [default: 10]
|
||||
--jobs=<n> Number of instances to solve simultaneously [default: 5]
|
||||
--train-time-limit=<n> Solver time limit during training in seconds [default: 3600]
|
||||
--test-time-limit=<n> Solver time limit during test in seconds [default: 900]
|
||||
--solver-threads=<n> Number of threads the solver is allowed to use [default: 4]
|
||||
@@ -24,27 +24,27 @@ import logging
|
||||
import pathlib
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
import gzip
|
||||
import glob
|
||||
|
||||
from docopt import docopt
|
||||
from numpy import median
|
||||
from pathlib import Path
|
||||
|
||||
from miplearn import LearningSolver, BenchmarkRunner
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s %(levelname).1s %(name)s: %(message)12s",
|
||||
datefmt="%H:%M:%S",
|
||||
level=logging.INFO,
|
||||
stream=sys.stdout,
|
||||
from miplearn import (
|
||||
LearningSolver,
|
||||
BenchmarkRunner,
|
||||
setup_logger,
|
||||
)
|
||||
|
||||
setup_logger()
|
||||
logging.getLogger("gurobipy").setLevel(logging.ERROR)
|
||||
logging.getLogger("pyomo.core").setLevel(logging.ERROR)
|
||||
logging.getLogger("miplearn").setLevel(logging.INFO)
|
||||
logger = logging.getLogger("benchmark")
|
||||
|
||||
args = docopt(__doc__)
|
||||
basepath = args["<challenge>"]
|
||||
pathlib.Path(basepath).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
n_jobs = int(args["--jobs"])
|
||||
n_threads = int(args["--solver-threads"])
|
||||
train_time_limit = int(args["--train-time-limit"])
|
||||
@@ -52,37 +52,49 @@ test_time_limit = int(args["--test-time-limit"])
|
||||
internal_solver = args["--solver"]
|
||||
|
||||
|
||||
def save(obj, filename):
|
||||
logger.info("Writing %s..." % filename)
|
||||
with open(filename, "wb") as file:
|
||||
def write_pickle_gz(obj, filename):
|
||||
logger.info(f"Writing: {filename}")
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
with gzip.GzipFile(filename, "wb") as file:
|
||||
pickle.dump(obj, file)
|
||||
|
||||
|
||||
def load(filename):
|
||||
import pickle
|
||||
|
||||
with open(filename, "rb") as file:
|
||||
def read_pickle_gz(filename):
|
||||
logger.info(f"Reading: {filename}")
|
||||
with gzip.GzipFile(filename, "rb") as file:
|
||||
return pickle.load(file)
|
||||
|
||||
|
||||
def write_multiple(objs, dirname):
|
||||
for (i, obj) in enumerate(objs):
|
||||
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
|
||||
|
||||
|
||||
def train():
|
||||
problem_name, challenge_name = args["<challenge>"].split("/")
|
||||
pkg = importlib.import_module("miplearn.problems.%s" % problem_name)
|
||||
pkg = importlib.import_module(f"miplearn.problems.{problem_name}")
|
||||
challenge = getattr(pkg, challenge_name)()
|
||||
train_instances = challenge.training_instances
|
||||
test_instances = challenge.test_instances
|
||||
|
||||
if not os.path.isdir(f"{basepath}/train"):
|
||||
write_multiple(challenge.training_instances, f"{basepath}/train")
|
||||
write_multiple(challenge.test_instances, f"{basepath}/test")
|
||||
|
||||
done_filename = f"{basepath}/train/done"
|
||||
if not os.path.isfile(done_filename):
|
||||
train_instances = glob.glob(f"{basepath}/train/*.gz")
|
||||
solver = LearningSolver(
|
||||
time_limit=train_time_limit,
|
||||
solver=internal_solver,
|
||||
threads=n_threads,
|
||||
)
|
||||
solver.parallel_solve(train_instances, n_jobs=n_jobs)
|
||||
save(train_instances, "%s/train_instances.bin" % basepath)
|
||||
save(test_instances, "%s/test_instances.bin" % basepath)
|
||||
Path(done_filename).touch(exist_ok=True)
|
||||
|
||||
|
||||
def test_baseline():
|
||||
test_instances = load("%s/test_instances.bin" % basepath)
|
||||
test_instances = glob.glob(f"{basepath}/test/*.gz")
|
||||
csv_filename = f"{basepath}/benchmark_baseline.csv"
|
||||
if not os.path.isfile(csv_filename):
|
||||
solvers = {
|
||||
"baseline": LearningSolver(
|
||||
time_limit=test_time_limit,
|
||||
@@ -92,13 +104,14 @@ def test_baseline():
|
||||
}
|
||||
benchmark = BenchmarkRunner(solvers)
|
||||
benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
|
||||
benchmark.save_results("%s/benchmark_baseline.csv" % basepath)
|
||||
benchmark.save_results(csv_filename)
|
||||
|
||||
|
||||
def test_ml():
|
||||
logger.info("Loading instances...")
|
||||
train_instances = load("%s/train_instances.bin" % basepath)
|
||||
test_instances = load("%s/test_instances.bin" % basepath)
|
||||
test_instances = glob.glob(f"{basepath}/test/*.gz")
|
||||
train_instances = glob.glob(f"{basepath}/train/*.gz")
|
||||
csv_filename = f"{basepath}/benchmark_ml.csv"
|
||||
if not os.path.isfile(csv_filename):
|
||||
solvers = {
|
||||
"ml-exact": LearningSolver(
|
||||
time_limit=test_time_limit,
|
||||
@@ -113,13 +126,9 @@ def test_ml():
|
||||
),
|
||||
}
|
||||
benchmark = BenchmarkRunner(solvers)
|
||||
logger.info("Loading results...")
|
||||
benchmark.load_results("%s/benchmark_baseline.csv" % basepath)
|
||||
logger.info("Fitting...")
|
||||
benchmark.fit(train_instances)
|
||||
logger.info("Solving...")
|
||||
benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
|
||||
benchmark.save_results("%s/benchmark_ml.csv" % basepath)
|
||||
benchmark.save_results(csv_filename)
|
||||
|
||||
|
||||
def charts():
|
||||
@@ -129,18 +138,19 @@ def charts():
|
||||
sns.set_style("whitegrid")
|
||||
sns.set_palette("Blues_r")
|
||||
benchmark = BenchmarkRunner({})
|
||||
benchmark.load_results("%s/benchmark_ml.csv" % basepath)
|
||||
benchmark.load_results(f"{basepath}/benchmark_baseline.csv")
|
||||
benchmark.load_results(f"{basepath}/benchmark_ml.csv")
|
||||
results = benchmark.raw_results()
|
||||
results["Gap (%)"] = results["Gap"] * 100.0
|
||||
|
||||
sense = results.loc[0, "Sense"]
|
||||
if sense == "min":
|
||||
primal_column = "Relative Upper Bound"
|
||||
obj_column = "Upper Bound"
|
||||
if (sense == "min").any():
|
||||
primal_column = "Relative upper bound"
|
||||
obj_column = "Upper bound"
|
||||
predicted_obj_column = "Predicted UB"
|
||||
else:
|
||||
primal_column = "Relative Lower Bound"
|
||||
obj_column = "Lower Bound"
|
||||
primal_column = "Relative lower bound"
|
||||
obj_column = "Lower bound"
|
||||
predicted_obj_column = "Predicted LB"
|
||||
|
||||
palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
|
||||
@@ -150,9 +160,11 @@ def charts():
|
||||
figsize=(12, 4),
|
||||
gridspec_kw={"width_ratios": [2, 1, 1, 2]},
|
||||
)
|
||||
|
||||
# Wallclock time
|
||||
sns.stripplot(
|
||||
x="Solver",
|
||||
y="Wallclock Time",
|
||||
y="Wallclock time",
|
||||
data=results,
|
||||
ax=ax1,
|
||||
jitter=0.25,
|
||||
@@ -161,7 +173,7 @@ def charts():
|
||||
)
|
||||
sns.barplot(
|
||||
x="Solver",
|
||||
y="Wallclock Time",
|
||||
y="Wallclock time",
|
||||
data=results,
|
||||
ax=ax1,
|
||||
errwidth=0.0,
|
||||
@@ -169,7 +181,9 @@ def charts():
|
||||
palette=palette,
|
||||
estimator=median,
|
||||
)
|
||||
ax1.set(ylabel="Wallclock Time (s)")
|
||||
ax1.set(ylabel="Wallclock time (s)")
|
||||
|
||||
# Gap
|
||||
ax2.set_ylim(-0.5, 5.5)
|
||||
sns.stripplot(
|
||||
x="Solver",
|
||||
@@ -180,6 +194,8 @@ def charts():
|
||||
palette=palette,
|
||||
size=4.0,
|
||||
)
|
||||
|
||||
# Relative primal bound
|
||||
ax3.set_ylim(0.95, 1.05)
|
||||
sns.stripplot(
|
||||
x="Solver",
|
||||
@@ -189,7 +205,6 @@ def charts():
|
||||
ax=ax3,
|
||||
palette=palette,
|
||||
)
|
||||
|
||||
sns.scatterplot(
|
||||
x=obj_column,
|
||||
y=predicted_obj_column,
|
||||
@@ -198,14 +213,29 @@ def charts():
|
||||
ax=ax4,
|
||||
palette=palette,
|
||||
)
|
||||
|
||||
# Predicted vs actual primal bound
|
||||
xlim, ylim = ax4.get_xlim(), ax4.get_ylim()
|
||||
ax4.plot([-1e10, 1e10], [-1e10, 1e10], ls="-", color="#cccccc")
|
||||
ax4.plot(
|
||||
[-1e10, 1e10],
|
||||
[-1e10, 1e10],
|
||||
ls="-",
|
||||
color="#cccccc",
|
||||
)
|
||||
ax4.set_xlim(xlim)
|
||||
ax4.set_ylim(ylim)
|
||||
ax4.get_legend().remove()
|
||||
ax4.set(
|
||||
ylabel="Predicted value",
|
||||
xlabel="Actual value",
|
||||
)
|
||||
|
||||
fig.tight_layout()
|
||||
plt.savefig("%s/performance.png" % basepath, bbox_inches="tight", dpi=150)
|
||||
plt.savefig(
|
||||
f"{basepath}/performance.png",
|
||||
bbox_inches="tight",
|
||||
dpi=150,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user