Make get_instance_features return np.ndarray

This commit is contained in:
2021-08-09 14:02:14 -05:00
parent 47d3011808
commit 56b39b6c9c
6 changed files with 29 additions and 16 deletions

View File

@@ -261,11 +261,18 @@ class FeaturesExtractor:
instance: "Instance",
sample: Sample,
) -> None:
features = cast(np.ndarray, instance.get_instance_features())
if isinstance(features, list):
features = np.array(features, dtype=float)
assert isinstance(features, np.ndarray)
assert features.dtype.kind in ["f"], f"Unsupported dtype: {features.dtype}"
features = instance.get_instance_features()
assert isinstance(features, np.ndarray), (
f"Instance features must be a numpy array. "
f"Found {features.__class__} instead."
)
assert len(features.shape) == 1, (
f"Instance features must be a vector. "
f"Found array with shape {features.shape} instead."
)
assert features.dtype.kind in [
"f"
], f"Instance features have unsupported dtype: {features.dtype}"
sample.put_array("static_instance_features", features)
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based