Fix benchmark scripts; add more input checks

This commit is contained in:
2021-04-03 07:57:22 -05:00
parent 0bce2051a8
commit 7a6b31ca9a
5 changed files with 49 additions and 32 deletions

View File

@@ -112,8 +112,14 @@ class Regressor(ABC):
assert isinstance(y_train, np.ndarray)
assert x_train.dtype in [np.float16, np.float32, np.float64]
assert y_train.dtype in [np.float16, np.float32, np.float64]
assert len(x_train.shape) == 2
assert len(y_train.shape) == 2
assert len(x_train.shape) == 2, (
f"Parameter x_train should be a square matrix. "
f"Found {x_train.shape} ndarray instead."
)
assert len(y_train.shape) == 2, (
f"Parameter y_train should be a square matrix. "
f"Found {y_train.shape} ndarray instead."
)
(n_samples_x, n_inputs) = x_train.shape
(n_samples_y, n_outputs) = y_train.shape
assert n_samples_y == n_samples_x

View File

@@ -41,13 +41,14 @@ class FeaturesExtractor:
assert isinstance(user_features, list), (
f"Variable features must be a list. "
f"Found {type(user_features).__name__} instead for "
f"var={var_name}."
)
assert isinstance(user_features[0], numbers.Real), (
f"Variable features must be a list of numbers."
f"Found {type(user_features[0]).__name__} instead "
f"for var={var_name}."
f"var={var_name}[{idx}]."
)
for v in user_features:
assert isinstance(v, numbers.Real), (
f"Variable features must be a list of numbers. "
f"Found {type(v).__name__} instead "
f"for var={var_name}[{idx}]."
)
var_dict[idx] = {
"Category": category,
"User features": user_features,
@@ -92,4 +93,15 @@ class FeaturesExtractor:
@staticmethod
def _extract_instance(instance: "Instance") -> InstanceFeatures:
    """Collect and validate the user-provided features of an instance.

    Calls ``instance.get_instance_features()`` and checks that the result
    is a list whose entries are all real numbers.

    Returns:
        A dictionary with the single key ``"User features"`` mapped to the
        validated list.

    Raises:
        AssertionError: if the features are not a list, or if any entry is
            not a ``numbers.Real``.
    """
    user_features = instance.get_instance_features()
    # The message must only use names available in this static method;
    # there is no variable name or index at the instance level.
    assert isinstance(user_features, list), (
        f"Instance features must be a list. "
        f"Found {type(user_features).__name__} instead."
    )
    for v in user_features:
        assert isinstance(v, numbers.Real), (
            f"Instance features must be a list of numbers. "
            f"Found {type(v).__name__} instead."
        )
    return {"User features": user_features}

View File

@@ -5,6 +5,10 @@
import logging
import sys
import time
import warnings
import traceback
_formatwarning = warnings.formatwarning
class TimeFormatter(logging.Formatter):
@@ -28,6 +32,13 @@ class TimeFormatter(logging.Formatter):
)
def formatwarning_tb(*args, **kwargs):
    """Format a warning and append the current call stack to it.

    All arguments are forwarded unchanged to ``_formatwarning``; the
    formatted stack is concatenated after the text it returns.
    """
    message = _formatwarning(*args, **kwargs)
    # format_stack() lists the most recent call last, so the final entry
    # is this function's own frame; leave it out of the report.
    frames = traceback.format_stack()[:-1]
    return message + "".join(frames)
def setup_logger(start_time=None, force_color=False):
if start_time is None:
start_time = time.time()
@@ -49,3 +60,5 @@ def setup_logger(start_time=None, force_color=False):
handler.setFormatter(TimeFormatter(start_time, log_colors))
logging.getLogger().addHandler(handler)
logging.getLogger("miplearn").setLevel(logging.INFO)
warnings.formatwarning = formatwarning_tb
logging.captureWarnings(True)

View File

@@ -85,24 +85,10 @@ class MultiKnapsackInstance(Instance):
return model
def get_instance_features(self):
    """Return the instance-level features as a plain list of numbers.

    The list holds the mean of ``self.prices`` followed by every entry of
    ``self.capacities``.
    """
    # Build a list rather than an ndarray (the superseded np.hstack form):
    # the features extractor asserts that instance features are a list.
    return [np.mean(self.prices)] + list(self.capacities)
def get_variable_features(self, var, index):
    """Return the features of one decision variable as a plain list.

    The list holds ``self.prices[index]`` followed by column ``index`` of
    ``self.weights``. The ``var`` argument is not used here.
    """
    # Build a list rather than an ndarray (the superseded np.hstack form):
    # the features extractor asserts that variable features are a list.
    return [self.prices[index]] + list(self.weights[:, index])
class MultiKnapsackGenerator: