Fix benchmark scripts; add more input checks

This commit is contained in:
2021-04-03 07:57:22 -05:00
parent 0bce2051a8
commit 7a6b31ca9a
5 changed files with 49 additions and 32 deletions

View File

@@ -61,7 +61,7 @@ def read_pickle_gz(filename):
return pickle.load(file)
def write_multiple(objs, dirname):
def write_pickle_gz_multiple(objs, dirname):
for (i, obj) in enumerate(objs):
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
@@ -73,8 +73,8 @@ def train(args):
challenge = getattr(pkg, challenge_name)()
if not os.path.isdir(f"{basepath}/train"):
write_multiple(challenge.training_instances, f"{basepath}/train")
write_multiple(challenge.test_instances, f"{basepath}/test")
write_pickle_gz_multiple(challenge.training_instances, f"{basepath}/train")
write_pickle_gz_multiple(challenge.test_instances, f"{basepath}/test")
done_filename = f"{basepath}/train/done"
if not os.path.isfile(done_filename):
@@ -114,7 +114,7 @@ def test_baseline(args):
test_instances,
n_jobs=int(args["--test-jobs"]),
)
benchmark.save_results(csv_filename)
benchmark.write_csv(csv_filename)
def test_ml(args):
@@ -148,7 +148,7 @@ def test_ml(args):
test_instances,
n_jobs=int(args["--test-jobs"]),
)
benchmark.save_results(csv_filename)
benchmark.write_csv(csv_filename)
def charts(args):
@@ -171,11 +171,11 @@ def charts(args):
if (sense == "min").any():
primal_column = "Relative upper bound"
obj_column = "Upper bound"
predicted_obj_column = "Predicted UB"
predicted_obj_column = "Objective: Predicted UB"
else:
primal_column = "Relative lower bound"
obj_column = "Lower bound"
predicted_obj_column = "Predicted LB"
predicted_obj_column = "Objective: Predicted LB"
palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
fig, (ax1, ax2, ax3, ax4) = plt.subplots(