From 56b39b6c9c2f00bde27a0e1103dcf698358da5fa Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Mon, 9 Aug 2021 14:02:14 -0500 Subject: [PATCH] Make get_instance_features return np.ndarray --- miplearn/features/extractor.py | 17 ++++++++++++----- miplearn/instance/base.py | 6 ++++-- miplearn/instance/file.py | 3 ++- miplearn/instance/picklegz.py | 3 ++- miplearn/problems/knapsack.py | 4 ++-- miplearn/solvers/pyomo/base.py | 12 +++++++----- 6 files changed, 29 insertions(+), 16 deletions(-) diff --git a/miplearn/features/extractor.py b/miplearn/features/extractor.py index 184694a..ead7f8f 100644 --- a/miplearn/features/extractor.py +++ b/miplearn/features/extractor.py @@ -261,11 +261,18 @@ class FeaturesExtractor: instance: "Instance", sample: Sample, ) -> None: - features = cast(np.ndarray, instance.get_instance_features()) - if isinstance(features, list): - features = np.array(features, dtype=float) - assert isinstance(features, np.ndarray) - assert features.dtype.kind in ["f"], f"Unsupported dtype: {features.dtype}" + features = instance.get_instance_features() + assert isinstance(features, np.ndarray), ( + f"Instance features must be a numpy array. " + f"Found {features.__class__} instead." + ) + assert len(features.shape) == 1, ( + f"Instance features must be a vector. " + f"Found array with shape {features.shape} instead." + ) + assert features.dtype.kind in [ + "f" + ], f"Instance features have unsupported dtype: {features.dtype}" sample.put_array("static_instance_features", features) # Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based diff --git a/miplearn/instance/base.py b/miplearn/instance/base.py index 3f0e7b2..1c5ba8a 100644 --- a/miplearn/instance/base.py +++ b/miplearn/instance/base.py @@ -6,6 +6,8 @@ import logging from abc import ABC, abstractmethod from typing import Any, List, TYPE_CHECKING, Dict +import numpy as np + from miplearn.features.sample import Sample, MemorySample logger = logging.getLogger(__name__) @@ -37,7 +39,7 @@ class Instance(ABC): """ pass - def get_instance_features(self) -> List[float]: + def get_instance_features(self) -> np.ndarray: """ Returns a 1-dimensional array of (numerical) features describing the entire instance. @@ -59,7 +61,7 @@ class Instance(ABC): By default, returns [0.0]. """ - return [0.0] + return np.zeros(1) def get_variable_features(self) -> Dict[str, List[float]]: """ diff --git a/miplearn/instance/file.py b/miplearn/instance/file.py index 14f9fdf..daf1816 100644 --- a/miplearn/instance/file.py +++ b/miplearn/instance/file.py @@ -6,6 +6,7 @@ import os from typing import Any, Optional, List, Dict, TYPE_CHECKING import pickle +import numpy as np from overrides import overrides from miplearn.features.sample import Hdf5Sample, Sample @@ -30,7 +31,7 @@ class FileInstance(Instance): return self.instance.to_model() @overrides - def get_instance_features(self) -> List[float]: + def get_instance_features(self) -> np.ndarray: assert self.instance is not None return self.instance.get_instance_features() diff --git a/miplearn/instance/picklegz.py b/miplearn/instance/picklegz.py index 8472a9d..b7b6b40 100644 --- a/miplearn/instance/picklegz.py +++ b/miplearn/instance/picklegz.py @@ -8,6 +8,7 @@ import os import pickle from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict +import numpy as np from overrides import overrides from miplearn.features.sample import Sample @@ -42,7 +43,7 @@ class PickleGzInstance(Instance): return self.instance.to_model() @overrides - def get_instance_features(self) -> List[float]: + def get_instance_features(self) -> np.ndarray: assert self.instance is not None return self.instance.get_instance_features() diff --git a/miplearn/problems/knapsack.py b/miplearn/problems/knapsack.py index 83df03f..2a922de 100644 --- a/miplearn/problems/knapsack.py +++ b/miplearn/problems/knapsack.py @@ -94,8 +94,8 @@ class MultiKnapsackInstance(Instance): return model @overrides - def get_instance_features(self) -> List[float]: - return [float(np.mean(self.prices))] + list(self.capacities) + def get_instance_features(self) -> np.ndarray: + return np.array([float(np.mean(self.prices))] + list(self.capacities)) @overrides def get_variable_features(self) -> Dict[str, List[float]]: diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index ed0ba59..acd0e37 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -622,11 +622,13 @@ class PyomoTestInstanceKnapsack(Instance): return model @overrides - def get_instance_features(self) -> List[float]: - return [ - self.capacity, - np.average(self.weights), - ] + def get_instance_features(self) -> np.ndarray: + return np.array( + [ + self.capacity, + np.average(self.weights), + ] + ) @overrides def get_variable_features(self) -> Dict[str, List[float]]: