Mirror of https://github.com/ANL-CEEESA/MIPLearn.git (synced 2025-12-06 01:18:52 -06:00)
Update benchmark script
@@ -13,8 +13,9 @@ Usage:
 
 Options:
     -h --help                 Show this screen
-    --jobs=<n>                Number of instances to solve simultaneously [default: 5]
+    --train-jobs=<n>          Number of instances to solve in parallel during training [default: 10]
     --train-time-limit=<n>    Solver time limit during training in seconds [default: 3600]
+    --test-jobs=<n>           Number of instances to solve in parallel during test [default: 5]
     --test-time-limit=<n>     Solver time limit during test in seconds [default: 900]
     --solver-threads=<n>      Number of threads the solver is allowed to use [default: 4]
 """
@@ -46,13 +47,6 @@ logging.getLogger("gurobipy").setLevel(logging.ERROR)
 logging.getLogger("pyomo.core").setLevel(logging.ERROR)
 logger = logging.getLogger("benchmark")
 
-args = docopt(__doc__)
-basepath = args["<challenge>"]
-n_jobs = int(args["--jobs"])
-n_threads = int(args["--solver-threads"])
-train_time_limit = int(args["--train-time-limit"])
-test_time_limit = int(args["--test-time-limit"])
-
 
 def write_pickle_gz(obj, filename):
     logger.info(f"Writing: {filename}")
@@ -72,7 +66,8 @@ def write_multiple(objs, dirname):
         write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
 
 
-def train():
+def train(args):
+    basepath = args["<challenge>"]
     problem_name, challenge_name = args["<challenge>"].split("/")
     pkg = importlib.import_module(f"miplearn.problems.{problem_name}")
     challenge = getattr(pkg, challenge_name)()
@@ -87,16 +82,20 @@ def train():
     solver = LearningSolver(
         solver=lambda: GurobiPyomoSolver(
             params={
-                "TimeLimit": train_time_limit,
-                "Threads": n_threads,
+                "TimeLimit": int(args["--train-time-limit"]),
+                "Threads": int(args["--solver-threads"]),
             }
         ),
     )
-    solver.parallel_solve(train_instances, n_jobs=n_jobs)
+    solver.parallel_solve(
+        train_instances,
+        n_jobs=int(args["--train-jobs"]),
+    )
     Path(done_filename).touch(exist_ok=True)
 
 
-def test_baseline():
+def test_baseline(args):
+    basepath = args["<challenge>"]
     test_instances = glob.glob(f"{basepath}/test/*.gz")
     csv_filename = f"{basepath}/benchmark_baseline.csv"
     if not os.path.isfile(csv_filename):
@@ -104,18 +103,22 @@ def test_baseline():
             "baseline": LearningSolver(
                 solver=lambda: GurobiPyomoSolver(
                     params={
-                        "TimeLimit": train_time_limit,
-                        "Threads": n_threads,
+                        "TimeLimit": int(args["--test-time-limit"]),
+                        "Threads": int(args["--solver-threads"]),
                     }
                 ),
             ),
         }
         benchmark = BenchmarkRunner(solvers)
-        benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
+        benchmark.parallel_solve(
+            test_instances,
+            n_jobs=int(args["--test-jobs"]),
+        )
         benchmark.save_results(csv_filename)
 
 
-def test_ml():
+def test_ml(args):
+    basepath = args["<challenge>"]
     test_instances = glob.glob(f"{basepath}/test/*.gz")
     train_instances = glob.glob(f"{basepath}/train/*.gz")
     csv_filename = f"{basepath}/benchmark_ml.csv"
@@ -124,16 +127,16 @@ def test_ml():
             "ml-exact": LearningSolver(
                 solver=lambda: GurobiPyomoSolver(
                     params={
-                        "TimeLimit": train_time_limit,
-                        "Threads": n_threads,
+                        "TimeLimit": int(args["--test-time-limit"]),
+                        "Threads": int(args["--solver-threads"]),
                     }
                 ),
             ),
             "ml-heuristic": LearningSolver(
                 solver=lambda: GurobiPyomoSolver(
                     params={
-                        "TimeLimit": train_time_limit,
-                        "Threads": n_threads,
+                        "TimeLimit": int(args["--test-time-limit"]),
+                        "Threads": int(args["--solver-threads"]),
                     }
                 ),
                 mode="heuristic",
@@ -141,11 +144,15 @@ def test_ml():
         }
         benchmark = BenchmarkRunner(solvers)
         benchmark.fit(train_instances)
-        benchmark.parallel_solve(test_instances, n_jobs=n_jobs)
+        benchmark.parallel_solve(
+            test_instances,
+            n_jobs=int(args["--test-jobs"]),
+        )
         benchmark.save_results(csv_filename)
 
 
-def charts():
+def charts(args):
+    basepath = args["<challenge>"]
     sns.set_style("whitegrid")
     sns.set_palette("Blues_r")
 
@@ -256,11 +263,12 @@ def charts():
 
 
 if __name__ == "__main__":
+    args = docopt(__doc__)
     if args["train"]:
-        train()
+        train(args)
     if args["test-baseline"]:
-        test_baseline()
+        test_baseline(args)
     if args["test-ml"]:
-        test_ml()
+        test_ml(args)
     if args["charts"]:
-        charts()
+        charts(args)
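
For context, the pattern this commit applies is to parse the docopt arguments once under __main__ and hand the resulting dict to each subcommand, replacing the module-level globals (and the single --jobs flag) with per-phase options read at the point of use. Below is a minimal, self-contained sketch of that pattern; the usage line is hypothetical, since the script's full usage section is not shown in this diff.

"""Usage:
    benchmark.py train <challenge> [--train-jobs=<n>]

Options:
    --train-jobs=<n>    Number of instances to solve in parallel during training [default: 10]
"""
from docopt import docopt


def train(args):
    # docopt returns strings, so options are converted at the point of use.
    n_jobs = int(args["--train-jobs"])
    print(f"would train {args['<challenge>']} with {n_jobs} parallel jobs")


if __name__ == "__main__":
    args = docopt(__doc__)
    if args["train"]:
        train(args)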