diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py index d8c9a31..1962c16 100755 --- a/benchmark/benchmark.py +++ b/benchmark/benchmark.py @@ -13,8 +13,9 @@ Usage: Options: -h --help Show this screen - --jobs= Number of instances to solve simultaneously [default: 5] + --train-jobs= Number of instances to solve in parallel during training [default: 10] --train-time-limit= Solver time limit during training in seconds [default: 3600] + --test-jobs= Number of instances to solve in parallel during test [default: 5] --test-time-limit= Solver time limit during test in seconds [default: 900] --solver-threads= Number of threads the solver is allowed to use [default: 4] """ @@ -46,13 +47,6 @@ logging.getLogger("gurobipy").setLevel(logging.ERROR) logging.getLogger("pyomo.core").setLevel(logging.ERROR) logger = logging.getLogger("benchmark") -args = docopt(__doc__) -basepath = args[""] -n_jobs = int(args["--jobs"]) -n_threads = int(args["--solver-threads"]) -train_time_limit = int(args["--train-time-limit"]) -test_time_limit = int(args["--test-time-limit"]) - def write_pickle_gz(obj, filename): logger.info(f"Writing: {filename}") @@ -72,7 +66,8 @@ def write_multiple(objs, dirname): write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz") -def train(): +def train(args): + basepath = args[""] problem_name, challenge_name = args[""].split("/") pkg = importlib.import_module(f"miplearn.problems.{problem_name}") challenge = getattr(pkg, challenge_name)() @@ -87,16 +82,20 @@ def train(): solver = LearningSolver( solver=lambda: GurobiPyomoSolver( params={ - "TimeLimit": train_time_limit, - "Threads": n_threads, + "TimeLimit": int(args["--train-time-limit"]), + "Threads": int(args["--solver-threads"]), } ), ) - solver.parallel_solve(train_instances, n_jobs=n_jobs) + solver.parallel_solve( + train_instances, + n_jobs=int(args["--train-jobs"]), + ) Path(done_filename).touch(exist_ok=True) -def test_baseline(): +def test_baseline(args): + basepath = args[""] test_instances = glob.glob(f"{basepath}/test/*.gz") csv_filename = f"{basepath}/benchmark_baseline.csv" if not os.path.isfile(csv_filename): @@ -104,18 +103,22 @@ def test_baseline(): "baseline": LearningSolver( solver=lambda: GurobiPyomoSolver( params={ - "TimeLimit": train_time_limit, - "Threads": n_threads, + "TimeLimit": int(args["--test-time-limit"]), + "Threads": int(args["--solver-threads"]), } ), ), } benchmark = BenchmarkRunner(solvers) - benchmark.parallel_solve(test_instances, n_jobs=n_jobs) + benchmark.parallel_solve( + test_instances, + n_jobs=int(args["--test-jobs"]), + ) benchmark.save_results(csv_filename) -def test_ml(): +def test_ml(args): + basepath = args[""] test_instances = glob.glob(f"{basepath}/test/*.gz") train_instances = glob.glob(f"{basepath}/train/*.gz") csv_filename = f"{basepath}/benchmark_ml.csv" @@ -124,16 +127,16 @@ def test_ml(): "ml-exact": LearningSolver( solver=lambda: GurobiPyomoSolver( params={ - "TimeLimit": train_time_limit, - "Threads": n_threads, + "TimeLimit": int(args["--test-time-limit"]), + "Threads": int(args["--solver-threads"]), } ), ), "ml-heuristic": LearningSolver( solver=lambda: GurobiPyomoSolver( params={ - "TimeLimit": train_time_limit, - "Threads": n_threads, + "TimeLimit": int(args["--test-time-limit"]), + "Threads": int(args["--solver-threads"]), } ), mode="heuristic", @@ -141,11 +144,15 @@ def test_ml(): } benchmark = BenchmarkRunner(solvers) benchmark.fit(train_instances) - benchmark.parallel_solve(test_instances, n_jobs=n_jobs) + benchmark.parallel_solve( + test_instances, + n_jobs=int(args["--test-jobs"]), + ) benchmark.save_results(csv_filename) -def charts(): +def charts(args): + basepath = args[""] sns.set_style("whitegrid") sns.set_palette("Blues_r") @@ -256,11 +263,12 @@ def charts(): if __name__ == "__main__": + args = docopt(__doc__) if args["train"]: - train() + train(args) if args["test-baseline"]: - test_baseline() + test_baseline(args) if args["test-ml"]: - test_ml() + test_ml(args) if args["charts"]: - charts() + charts(args)