mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Fix benchmark scripts; add more input checks
This commit is contained in:
@@ -61,7 +61,7 @@ def read_pickle_gz(filename):
|
|||||||
return pickle.load(file)
|
return pickle.load(file)
|
||||||
|
|
||||||
|
|
||||||
def write_multiple(objs, dirname):
|
def write_pickle_gz_multiple(objs, dirname):
|
||||||
for (i, obj) in enumerate(objs):
|
for (i, obj) in enumerate(objs):
|
||||||
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
|
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
|
||||||
|
|
||||||
@@ -73,8 +73,8 @@ def train(args):
|
|||||||
challenge = getattr(pkg, challenge_name)()
|
challenge = getattr(pkg, challenge_name)()
|
||||||
|
|
||||||
if not os.path.isdir(f"{basepath}/train"):
|
if not os.path.isdir(f"{basepath}/train"):
|
||||||
write_multiple(challenge.training_instances, f"{basepath}/train")
|
write_pickle_gz_multiple(challenge.training_instances, f"{basepath}/train")
|
||||||
write_multiple(challenge.test_instances, f"{basepath}/test")
|
write_pickle_gz_multiple(challenge.test_instances, f"{basepath}/test")
|
||||||
|
|
||||||
done_filename = f"{basepath}/train/done"
|
done_filename = f"{basepath}/train/done"
|
||||||
if not os.path.isfile(done_filename):
|
if not os.path.isfile(done_filename):
|
||||||
@@ -114,7 +114,7 @@ def test_baseline(args):
|
|||||||
test_instances,
|
test_instances,
|
||||||
n_jobs=int(args["--test-jobs"]),
|
n_jobs=int(args["--test-jobs"]),
|
||||||
)
|
)
|
||||||
benchmark.save_results(csv_filename)
|
benchmark.write_csv(csv_filename)
|
||||||
|
|
||||||
|
|
||||||
def test_ml(args):
|
def test_ml(args):
|
||||||
@@ -148,7 +148,7 @@ def test_ml(args):
|
|||||||
test_instances,
|
test_instances,
|
||||||
n_jobs=int(args["--test-jobs"]),
|
n_jobs=int(args["--test-jobs"]),
|
||||||
)
|
)
|
||||||
benchmark.save_results(csv_filename)
|
benchmark.write_csv(csv_filename)
|
||||||
|
|
||||||
|
|
||||||
def charts(args):
|
def charts(args):
|
||||||
@@ -171,11 +171,11 @@ def charts(args):
|
|||||||
if (sense == "min").any():
|
if (sense == "min").any():
|
||||||
primal_column = "Relative upper bound"
|
primal_column = "Relative upper bound"
|
||||||
obj_column = "Upper bound"
|
obj_column = "Upper bound"
|
||||||
predicted_obj_column = "Predicted UB"
|
predicted_obj_column = "Objective: Predicted UB"
|
||||||
else:
|
else:
|
||||||
primal_column = "Relative lower bound"
|
primal_column = "Relative lower bound"
|
||||||
obj_column = "Lower bound"
|
obj_column = "Lower bound"
|
||||||
predicted_obj_column = "Predicted LB"
|
predicted_obj_column = "Objective: Predicted LB"
|
||||||
|
|
||||||
palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
|
palette = {"baseline": "#9b59b6", "ml-exact": "#3498db", "ml-heuristic": "#95a5a6"}
|
||||||
fig, (ax1, ax2, ax3, ax4) = plt.subplots(
|
fig, (ax1, ax2, ax3, ax4) = plt.subplots(
|
||||||
|
|||||||
@@ -112,8 +112,14 @@ class Regressor(ABC):
|
|||||||
assert isinstance(y_train, np.ndarray)
|
assert isinstance(y_train, np.ndarray)
|
||||||
assert x_train.dtype in [np.float16, np.float32, np.float64]
|
assert x_train.dtype in [np.float16, np.float32, np.float64]
|
||||||
assert y_train.dtype in [np.float16, np.float32, np.float64]
|
assert y_train.dtype in [np.float16, np.float32, np.float64]
|
||||||
assert len(x_train.shape) == 2
|
assert len(x_train.shape) == 2, (
|
||||||
assert len(y_train.shape) == 2
|
f"Parameter x_train should be a square matrix. "
|
||||||
|
f"Found {x_train.shape} ndarray instead."
|
||||||
|
)
|
||||||
|
assert len(y_train.shape) == 2, (
|
||||||
|
f"Parameter y_train should be a square matrix. "
|
||||||
|
f"Found {y_train.shape} ndarray instead."
|
||||||
|
)
|
||||||
(n_samples_x, n_inputs) = x_train.shape
|
(n_samples_x, n_inputs) = x_train.shape
|
||||||
(n_samples_y, n_outputs) = y_train.shape
|
(n_samples_y, n_outputs) = y_train.shape
|
||||||
assert n_samples_y == n_samples_x
|
assert n_samples_y == n_samples_x
|
||||||
|
|||||||
@@ -41,12 +41,13 @@ class FeaturesExtractor:
|
|||||||
assert isinstance(user_features, list), (
|
assert isinstance(user_features, list), (
|
||||||
f"Variable features must be a list. "
|
f"Variable features must be a list. "
|
||||||
f"Found {type(user_features).__name__} instead for "
|
f"Found {type(user_features).__name__} instead for "
|
||||||
f"var={var_name}."
|
f"var={var_name}[{idx}]."
|
||||||
)
|
)
|
||||||
assert isinstance(user_features[0], numbers.Real), (
|
for v in user_features:
|
||||||
|
assert isinstance(v, numbers.Real), (
|
||||||
f"Variable features must be a list of numbers. "
|
f"Variable features must be a list of numbers. "
|
||||||
f"Found {type(user_features[0]).__name__} instead "
|
f"Found {type(v).__name__} instead "
|
||||||
f"for var={var_name}."
|
f"for var={var_name}[{idx}]."
|
||||||
)
|
)
|
||||||
var_dict[idx] = {
|
var_dict[idx] = {
|
||||||
"Category": category,
|
"Category": category,
|
||||||
@@ -92,4 +93,15 @@ class FeaturesExtractor:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_instance(instance: "Instance") -> InstanceFeatures:
|
def _extract_instance(instance: "Instance") -> InstanceFeatures:
|
||||||
return {"User features": instance.get_instance_features()}
|
user_features = instance.get_instance_features()
|
||||||
|
assert isinstance(user_features, list), (
|
||||||
|
f"Instance features must be a list. "
|
||||||
|
f"Found {type(user_features).__name__} instead for "
|
||||||
|
f"var={var_name}[{idx}]."
|
||||||
|
)
|
||||||
|
for v in user_features:
|
||||||
|
assert isinstance(v, numbers.Real), (
|
||||||
|
f"Instance features must be a list of numbers. "
|
||||||
|
f"Found {type(v).__name__} instead."
|
||||||
|
)
|
||||||
|
return {"User features": user_features}
|
||||||
|
|||||||
@@ -5,6 +5,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import warnings
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
_formatwarning = warnings.formatwarning
|
||||||
|
|
||||||
|
|
||||||
class TimeFormatter(logging.Formatter):
|
class TimeFormatter(logging.Formatter):
|
||||||
@@ -28,6 +32,13 @@ class TimeFormatter(logging.Formatter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def formatwarning_tb(*args, **kwargs):
|
||||||
|
s = _formatwarning(*args, **kwargs)
|
||||||
|
tb = traceback.format_stack()
|
||||||
|
s += "".join(tb[:-1])
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(start_time=None, force_color=False):
|
def setup_logger(start_time=None, force_color=False):
|
||||||
if start_time is None:
|
if start_time is None:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -49,3 +60,5 @@ def setup_logger(start_time=None, force_color=False):
|
|||||||
handler.setFormatter(TimeFormatter(start_time, log_colors))
|
handler.setFormatter(TimeFormatter(start_time, log_colors))
|
||||||
logging.getLogger().addHandler(handler)
|
logging.getLogger().addHandler(handler)
|
||||||
logging.getLogger("miplearn").setLevel(logging.INFO)
|
logging.getLogger("miplearn").setLevel(logging.INFO)
|
||||||
|
warnings.formatwarning = formatwarning_tb
|
||||||
|
logging.captureWarnings(True)
|
||||||
|
|||||||
@@ -85,24 +85,10 @@ class MultiKnapsackInstance(Instance):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def get_instance_features(self):
|
def get_instance_features(self):
|
||||||
return np.hstack(
|
return [np.mean(self.prices)] + list(self.capacities)
|
||||||
[
|
|
||||||
np.mean(self.prices),
|
|
||||||
self.capacities,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_variable_features(self, var, index):
|
def get_variable_features(self, var, index):
|
||||||
return np.hstack(
|
return [self.prices[index]] + list(self.weights[:, index])
|
||||||
[
|
|
||||||
self.prices[index],
|
|
||||||
self.weights[:, index],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# def get_variable_category(self, var, index):
|
|
||||||
# return index
|
|
||||||
|
|
||||||
|
|
||||||
class MultiKnapsackGenerator:
|
class MultiKnapsackGenerator:
|
||||||
|
|||||||
Reference in New Issue
Block a user