mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Make get_instance_features return np.ndarray
This commit is contained in:
@@ -261,11 +261,18 @@ class FeaturesExtractor:
|
|||||||
instance: "Instance",
|
instance: "Instance",
|
||||||
sample: Sample,
|
sample: Sample,
|
||||||
) -> None:
|
) -> None:
|
||||||
features = cast(np.ndarray, instance.get_instance_features())
|
features = instance.get_instance_features()
|
||||||
if isinstance(features, list):
|
assert isinstance(features, np.ndarray), (
|
||||||
features = np.array(features, dtype=float)
|
f"Instance features must be a numpy array. "
|
||||||
assert isinstance(features, np.ndarray)
|
f"Found {features.__class__} instead."
|
||||||
assert features.dtype.kind in ["f"], f"Unsupported dtype: {features.dtype}"
|
)
|
||||||
|
assert len(features.shape) == 1, (
|
||||||
|
f"Instance features must be a vector. "
|
||||||
|
f"Found array with shape {features.shape} instead."
|
||||||
|
)
|
||||||
|
assert features.dtype.kind in [
|
||||||
|
"f"
|
||||||
|
], f"Instance features have unsupported dtype: {features.dtype}"
|
||||||
sample.put_array("static_instance_features", features)
|
sample.put_array("static_instance_features", features)
|
||||||
|
|
||||||
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
|
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import logging
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, List, TYPE_CHECKING, Dict
|
from typing import Any, List, TYPE_CHECKING, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from miplearn.features.sample import Sample, MemorySample
|
from miplearn.features.sample import Sample, MemorySample
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -37,7 +39,7 @@ class Instance(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_instance_features(self) -> List[float]:
|
def get_instance_features(self) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Returns a 1-dimensional array of (numerical) features describing the
|
Returns a 1-dimensional array of (numerical) features describing the
|
||||||
entire instance.
|
entire instance.
|
||||||
@@ -59,7 +61,7 @@ class Instance(ABC):
|
|||||||
|
|
||||||
By default, returns [0.0].
|
By default, returns [0.0].
|
||||||
"""
|
"""
|
||||||
return [0.0]
|
return np.zeros(1)
|
||||||
|
|
||||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
from typing import Any, Optional, List, Dict, TYPE_CHECKING
|
from typing import Any, Optional, List, Dict, TYPE_CHECKING
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
from miplearn.features.sample import Hdf5Sample, Sample
|
from miplearn.features.sample import Hdf5Sample, Sample
|
||||||
@@ -30,7 +31,7 @@ class FileInstance(Instance):
|
|||||||
return self.instance.to_model()
|
return self.instance.to_model()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_instance_features(self) -> List[float]:
|
def get_instance_features(self) -> np.ndarray:
|
||||||
assert self.instance is not None
|
assert self.instance is not None
|
||||||
return self.instance.get_instance_features()
|
return self.instance.get_instance_features()
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
|
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
from miplearn.features.sample import Sample
|
from miplearn.features.sample import Sample
|
||||||
@@ -42,7 +43,7 @@ class PickleGzInstance(Instance):
|
|||||||
return self.instance.to_model()
|
return self.instance.to_model()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_instance_features(self) -> List[float]:
|
def get_instance_features(self) -> np.ndarray:
|
||||||
assert self.instance is not None
|
assert self.instance is not None
|
||||||
return self.instance.get_instance_features()
|
return self.instance.get_instance_features()
|
||||||
|
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ class MultiKnapsackInstance(Instance):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_instance_features(self) -> List[float]:
|
def get_instance_features(self) -> np.ndarray:
|
||||||
return [float(np.mean(self.prices))] + list(self.capacities)
|
return np.array([float(np.mean(self.prices))] + list(self.capacities))
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||||
|
|||||||
@@ -622,11 +622,13 @@ class PyomoTestInstanceKnapsack(Instance):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_instance_features(self) -> List[float]:
|
def get_instance_features(self) -> np.ndarray:
|
||||||
return [
|
return np.array(
|
||||||
|
[
|
||||||
self.capacity,
|
self.capacity,
|
||||||
np.average(self.weights),
|
np.average(self.weights),
|
||||||
]
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||||
|
|||||||
Reference in New Issue
Block a user