Update benchmark scripts

master
Alinson S. Xavier 5 years ago
parent 203afc6993
commit a0062edb5a

@ -17,7 +17,6 @@ Options:
--train-time-limit=<n> Solver time limit during training in seconds [default: 3600] --train-time-limit=<n> Solver time limit during training in seconds [default: 3600]
--test-time-limit=<n> Solver time limit during test in seconds [default: 900] --test-time-limit=<n> Solver time limit during test in seconds [default: 900]
--solver-threads=<n> Number of threads the solver is allowed to use [default: 4] --solver-threads=<n> Number of threads the solver is allowed to use [default: 4]
--solver=<s> Internal MILP solver to use [default: gurobi]
""" """
import importlib import importlib
import logging import logging
@ -38,6 +37,7 @@ import seaborn as sns
from miplearn import ( from miplearn import (
LearningSolver, LearningSolver,
BenchmarkRunner, BenchmarkRunner,
GurobiPyomoSolver,
setup_logger, setup_logger,
) )
@ -52,7 +52,6 @@ n_jobs = int(args["--jobs"])
n_threads = int(args["--solver-threads"]) n_threads = int(args["--solver-threads"])
train_time_limit = int(args["--train-time-limit"]) train_time_limit = int(args["--train-time-limit"])
test_time_limit = int(args["--test-time-limit"]) test_time_limit = int(args["--test-time-limit"])
internal_solver = args["--solver"]
def write_pickle_gz(obj, filename): def write_pickle_gz(obj, filename):
@ -86,9 +85,12 @@ def train():
if not os.path.isfile(done_filename): if not os.path.isfile(done_filename):
train_instances = glob.glob(f"{basepath}/train/*.gz") train_instances = glob.glob(f"{basepath}/train/*.gz")
solver = LearningSolver( solver = LearningSolver(
time_limit=train_time_limit, solver=lambda: GurobiPyomoSolver(
solver=internal_solver, params={
threads=n_threads, "TimeLimit": train_time_limit,
"Threads": n_threads,
}
),
) )
solver.parallel_solve(train_instances, n_jobs=n_jobs) solver.parallel_solve(train_instances, n_jobs=n_jobs)
Path(done_filename).touch(exist_ok=True) Path(done_filename).touch(exist_ok=True)
@ -100,9 +102,12 @@ def test_baseline():
if not os.path.isfile(csv_filename): if not os.path.isfile(csv_filename):
solvers = { solvers = {
"baseline": LearningSolver( "baseline": LearningSolver(
time_limit=test_time_limit, solver=lambda: GurobiPyomoSolver(
solver=internal_solver, params={
threads=n_threads, "TimeLimit": train_time_limit,
"Threads": n_threads,
}
),
), ),
} }
benchmark = BenchmarkRunner(solvers) benchmark = BenchmarkRunner(solvers)
@ -117,14 +122,20 @@ def test_ml():
if not os.path.isfile(csv_filename): if not os.path.isfile(csv_filename):
solvers = { solvers = {
"ml-exact": LearningSolver( "ml-exact": LearningSolver(
time_limit=test_time_limit, solver=lambda: GurobiPyomoSolver(
solver=internal_solver, params={
threads=n_threads, "TimeLimit": train_time_limit,
"Threads": n_threads,
}
),
), ),
"ml-heuristic": LearningSolver( "ml-heuristic": LearningSolver(
time_limit=test_time_limit, solver=lambda: GurobiPyomoSolver(
solver=internal_solver, params={
threads=n_threads, "TimeLimit": train_time_limit,
"Threads": n_threads,
}
),
mode="heuristic", mode="heuristic",
), ),
} }

Loading…
Cancel
Save