mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Fix benchmark scripts; add more input checks
This commit is contained in:
@@ -61,7 +61,7 @@ def read_pickle_gz(filename):
|
||||
return pickle.load(file)
|
||||
|
||||
|
||||
def write_multiple(objs, dirname):
|
||||
def write_pickle_gz_multiple(objs, dirname):
|
||||
for (i, obj) in enumerate(objs):
|
||||
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
|
||||
|
||||
@@ -73,8 +73,8 @@ def train(args):
|
||||
challenge = getattr(pkg, challenge_name)()
|
||||
|
||||
if not os.path.isdir(f"{basepath}/train"):
|
||||
write_multiple(challenge.training_instances, f"{basepath}/train")
|
||||
write_multiple(challenge.test_instances, f"{basepath}/test")
|
||||
write_pickle_gz_multiple(challenge.training_instances, f"{basepath}/train")
|
||||
write_pickle_gz_multiple(challenge.test_instances, f"{basepath}/test")
|
||||
|
||||
done_filename = f"{basepath}/train/done"
|
||||
if not os.path.isfile(done_filename):
|
||||
@@ -114,7 +114,7 @@ def test_baseline(args):
|
||||
test_instances,
|
||||
n_jobs=int(args["--test-jobs"]),
|
||||
)
|
||||
benchmark.save_results(csv_filename)
|
||||
benchmark.write_csv(csv_filename)
|
||||
|
||||
|
||||
def test_ml(args):
|
||||
@@ -148,7 +148,7 @@ def test_ml(args):
|
||||
test_instances,
|
||||
n_jobs=int(args["--test-jobs"]),
|
||||
)
|
||||
benchmark.save_results(csv_filename)
|
||||
benchmark.write_csv(csv_filename)
|
||||
|
||||
|
||||
def charts(args):
|
||||
@@ -171,11 +171,11 @@ def charts(args):
|
||||
if (sense == "min").any():
|
||||
primal_column = "Relative upper bound"
|
||||
obj_column = "Upper bound"
|
||||
predicted_obj_column = "Predicted UB"
|
||||
predicted_obj_column = "Objective: Predicted UB"
|
||||
else:
|
||||
primal_column = "Relative lower bound"
|
||||
obj_column = "Lower bound"
|
||||
predicted_obj_column = "Predicted LB"
|
||||
predicted_obj_column = "Objective: Predicted LB"
|
||||
|
||||
palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
|
||||
fig, (ax1, ax2, ax3, ax4) = plt.subplots(
|
||||
|
||||
Reference in New Issue
Block a user