Mirror of https://github.com/ANL-CEEESA/MIPLearn.git (synced 2025-12-06 01:18:52 -06:00)
Fix benchmark scripts; add more input checks
@@ -61,7 +61,7 @@ def read_pickle_gz(filename):
         return pickle.load(file)


-def write_multiple(objs, dirname):
+def write_pickle_gz_multiple(objs, dirname):
     for (i, obj) in enumerate(objs):
         write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")

@@ -73,8 +73,8 @@ def train(args):
     challenge = getattr(pkg, challenge_name)()

     if not os.path.isdir(f"{basepath}/train"):
-        write_multiple(challenge.training_instances, f"{basepath}/train")
-        write_multiple(challenge.test_instances, f"{basepath}/test")
+        write_pickle_gz_multiple(challenge.training_instances, f"{basepath}/train")
+        write_pickle_gz_multiple(challenge.test_instances, f"{basepath}/test")

     done_filename = f"{basepath}/train/done"
     if not os.path.isfile(done_filename):

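The helper was renamed from write_multiple to write_pickle_gz_multiple, so the call sites in train() spell out the on-disk format. Below is a minimal, self-contained sketch of the write path, assuming the same gzip-compressed pickle layout; write_pickle_gz is reconstructed here for illustration and may differ from the script's actual helper.

import gzip
import os
import pickle


def write_pickle_gz(obj, filename):
    # Store the object as a gzip-compressed pickle, creating the
    # target directory if it does not exist yet.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with gzip.GzipFile(filename, "wb") as file:
        pickle.dump(obj, file)


def write_pickle_gz_multiple(objs, dirname):
    # One zero-padded file per object: 00000.pkl.gz, 00001.pkl.gz, ...
    for (i, obj) in enumerate(objs):
        write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")


# Example: persist three dummy training instances under ./train
write_pickle_gz_multiple([{"id": i} for i in range(3)], "train")
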
@@ -114,7 +114,7 @@ def test_baseline(args):
         test_instances,
         n_jobs=int(args["--test-jobs"]),
     )
-    benchmark.save_results(csv_filename)
+    benchmark.write_csv(csv_filename)


 def test_ml(args):

@@ -148,7 +148,7 @@ def test_ml(args):
         test_instances,
         n_jobs=int(args["--test-jobs"]),
     )
-    benchmark.save_results(csv_filename)
+    benchmark.write_csv(csv_filename)


 def charts(args):

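Both test_baseline and test_ml switch from benchmark.save_results to benchmark.write_csv, matching the renamed method on the benchmark runner. The name suggests a plain CSV dump of the results table; the sketch below is a hypothetical stand-in for such a method (the class name and the results attribute are assumptions, not MIPLearn's actual API).

import pandas as pd


class BenchmarkRunnerSketch:
    """Illustrative stand-in for the benchmark runner; not MIPLearn's class."""

    def __init__(self):
        # Results are assumed to accumulate in a pandas DataFrame.
        self.results = pd.DataFrame()

    def write_csv(self, csv_filename):
        # Dump the accumulated results table as plain CSV.
        self.results.to_csv(csv_filename, index=False)


runner = BenchmarkRunnerSketch()
runner.write_csv("benchmark_baseline.csv")
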
@@ -171,11 +171,11 @@ def charts(args):
     if (sense == "min").any():
         primal_column = "Relative upper bound"
         obj_column = "Upper bound"
-        predicted_obj_column = "Predicted UB"
+        predicted_obj_column = "Objective: Predicted UB"
     else:
         primal_column = "Relative lower bound"
         obj_column = "Lower bound"
-        predicted_obj_column = "Predicted LB"
+        predicted_obj_column = "Objective: Predicted LB"

     palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
     fig, (ax1, ax2, ax3, ax4) = plt.subplots(

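The chart code now reads the predicted bounds from columns named "Objective: Predicted UB" / "Objective: Predicted LB", presumably matching the headers the benchmark results table actually writes. One convenient side effect of prefixed names is that related columns can be selected as a group; a small pandas illustration with made-up data:

import pandas as pd

df = pd.DataFrame(
    {
        "Upper bound": [10.0, 12.0],
        "Objective: Predicted UB": [9.8, 12.3],
        "Relative upper bound": [1.00, 1.02],
    }
)

# Every column in the "Objective:" namespace can be selected at once.
print(df.filter(like="Objective:").columns.tolist())
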
@@ -112,8 +112,14 @@ class Regressor(ABC):
         assert isinstance(y_train, np.ndarray)
         assert x_train.dtype in [np.float16, np.float32, np.float64]
         assert y_train.dtype in [np.float16, np.float32, np.float64]
-        assert len(x_train.shape) == 2
-        assert len(y_train.shape) == 2
+        assert len(x_train.shape) == 2, (
+            f"Parameter x_train should be a square matrix. "
+            f"Found {x_train.shape} ndarray instead."
+        )
+        assert len(y_train.shape) == 2, (
+            f"Parameter y_train should be a square matrix. "
+            f"Found {y_train.shape} ndarray instead."
+        )
         (n_samples_x, n_inputs) = x_train.shape
         (n_samples_y, n_outputs) = y_train.shape
         assert n_samples_y == n_samples_x

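The shape assertions now explain what they expect: both x_train and y_train must be two-dimensional float ndarrays with one row per sample. A small sketch of inputs that satisfy the checks (shapes chosen arbitrarily for illustration):

import numpy as np

n_samples, n_inputs, n_outputs = 8, 5, 1

# Two-dimensional float arrays with matching first dimensions pass the checks.
x_train = np.random.rand(n_samples, n_inputs).astype(np.float64)
y_train = np.random.rand(n_samples, n_outputs).astype(np.float64)

assert x_train.dtype in [np.float16, np.float32, np.float64]
assert len(x_train.shape) == 2 and len(y_train.shape) == 2
assert x_train.shape[0] == y_train.shape[0]

# A one-dimensional target now fails with the descriptive message; reshape it:
y_train_1d = np.random.rand(n_samples).astype(np.float64)
y_train_fixed = y_train_1d.reshape(-1, 1)  # shape (n_samples, 1)
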
@@ -41,12 +41,13 @@ class FeaturesExtractor:
                 assert isinstance(user_features, list), (
                     f"Variable features must be a list. "
                     f"Found {type(user_features).__name__} instead for "
-                    f"var={var_name}."
+                    f"var={var_name}[{idx}]."
                 )
-                assert isinstance(user_features[0], numbers.Real), (
-                    f"Variable features must be a list of numbers. "
-                    f"Found {type(user_features[0]).__name__} instead "
-                    f"for var={var_name}."
-                )
+                for v in user_features:
+                    assert isinstance(v, numbers.Real), (
+                        f"Variable features must be a list of numbers. "
+                        f"Found {type(v).__name__} instead "
+                        f"for var={var_name}[{idx}]."
+                    )
                 var_dict[idx] = {
                     "Category": category,

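The check now walks every entry of the feature list instead of inspecting only the first element, so a single bad value anywhere in the list is reported together with the variable name and index. A standalone sketch of the tightened check; under the old check only user_features[0] was inspected, so the list below would have slipped through:

import numbers

var_name, idx = "x", 3
user_features = [1.0, 2.5, "oops"]  # the string is now caught

try:
    assert isinstance(user_features, list)
    for v in user_features:
        assert isinstance(v, numbers.Real), (
            f"Variable features must be a list of numbers. "
            f"Found {type(v).__name__} instead "
            f"for var={var_name}[{idx}]."
        )
except AssertionError as err:
    print(err)  # Variable features must be a list of numbers. Found str instead for var=x[3].
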
@@ -92,4 +93,15 @@ class FeaturesExtractor:

     @staticmethod
     def _extract_instance(instance: "Instance") -> InstanceFeatures:
-        return {"User features": instance.get_instance_features()}
+        user_features = instance.get_instance_features()
+        assert isinstance(user_features, list), (
+            f"Instance features must be a list. "
+            f"Found {type(user_features).__name__} instead for "
+            f"var={var_name}[{idx}]."
+        )
+        for v in user_features:
+            assert isinstance(v, numbers.Real), (
+                f"Instance features must be a list of numbers. "
+                f"Found {type(v).__name__} instead."
+            )
+        return {"User features": user_features}

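Instance features get the same treatment: _extract_instance now validates that get_instance_features() returns a plain Python list of numbers before wrapping it, which is presumably why the MultiKnapsackInstance hunk further down switches from np.hstack (an ndarray) to list concatenation. A sketch, with made-up values, of a return value that passes versus one that the new isinstance(..., list) check rejects:

import numbers
import numpy as np

mean_price = 7.5
capacities = [100.0, 80.0]

good = [mean_price] + list(capacities)     # plain list of floats: accepted
bad = np.hstack([mean_price, capacities])  # ndarray, not a list: now rejected

assert isinstance(good, list)
assert all(isinstance(v, numbers.Real) for v in good)
print(isinstance(bad, list))  # False -> triggers the new assertion message
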
@@ -5,6 +5,10 @@
 import logging
 import sys
 import time
+import warnings
+import traceback
+
+_formatwarning = warnings.formatwarning


 class TimeFormatter(logging.Formatter):

@@ -28,6 +32,13 @@ class TimeFormatter(logging.Formatter):
         )


+def formatwarning_tb(*args, **kwargs):
+    s = _formatwarning(*args, **kwargs)
+    tb = traceback.format_stack()
+    s += "".join(tb[:-1])
+    return s
+
+
 def setup_logger(start_time=None, force_color=False):
     if start_time is None:
         start_time = time.time()

@@ -49,3 +60,5 @@ def setup_logger(start_time=None, force_color=False):
     handler.setFormatter(TimeFormatter(start_time, log_colors))
     logging.getLogger().addHandler(handler)
     logging.getLogger("miplearn").setLevel(logging.INFO)
+    warnings.formatwarning = formatwarning_tb
+    logging.captureWarnings(True)

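Wrapping warnings.formatwarning and then enabling logging.captureWarnings routes every warning through the logging handlers with a stack trace appended, which makes it much easier to locate where a warning was emitted. A self-contained sketch of the same monkey-patching idea, outside of MIPLearn's setup_logger:

import logging
import traceback
import warnings

_formatwarning = warnings.formatwarning


def formatwarning_tb(*args, **kwargs):
    # Keep the standard one-line warning text, then append the call stack
    # (minus this frame) so the warning's origin shows up in the log.
    s = _formatwarning(*args, **kwargs)
    s += "".join(traceback.format_stack()[:-1])
    return s


logging.basicConfig(level=logging.INFO)
warnings.formatwarning = formatwarning_tb
logging.captureWarnings(True)  # warnings are routed to the "py.warnings" logger

warnings.warn("this option is deprecated")  # logged together with a traceback
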
@@ -85,24 +85,10 @@ class MultiKnapsackInstance(Instance):
         return model

     def get_instance_features(self):
-        return np.hstack(
-            [
-                np.mean(self.prices),
-                self.capacities,
-            ]
-        )
+        return [np.mean(self.prices)] + list(self.capacities)

     def get_variable_features(self, var, index):
-        return np.hstack(
-            [
-                self.prices[index],
-                self.weights[:, index],
-            ]
-        )
-
-
-    # def get_variable_category(var, index):
-    #     return index
+        return [self.prices[index]] + list(self.weights[:, index])


 class MultiKnapsackGenerator:

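The knapsack example now returns its features as plain Python lists, matching the new input checks in FeaturesExtractor; the numeric content is unchanged from the old np.hstack version. A quick equivalence check with made-up prices, weights, and capacities:

import numpy as np

prices = np.array([3.0, 5.0, 2.0])
weights = np.array([[1.0, 4.0, 2.0],
                    [2.0, 1.0, 3.0]])
capacities = np.array([10.0, 8.0])
index = 1

old_instance = np.hstack([np.mean(prices), capacities])
new_instance = [np.mean(prices)] + list(capacities)

old_variable = np.hstack([prices[index], weights[:, index]])
new_variable = [prices[index]] + list(weights[:, index])

# Same numbers, but the new return values are plain lists and therefore
# pass the isinstance(..., list) checks in FeaturesExtractor.
assert list(old_instance) == new_instance
assert list(old_variable) == new_variable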