mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Make get_instance_features return np.ndarray
This commit is contained in:
@@ -6,6 +6,8 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, TYPE_CHECKING, Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
from miplearn.features.sample import Sample, MemorySample
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,7 +39,7 @@ class Instance(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_instance_features(self) -> List[float]:
|
||||
def get_instance_features(self) -> np.ndarray:
|
||||
"""
|
||||
Returns a 1-dimensional array of (numerical) features describing the
|
||||
entire instance.
|
||||
@@ -59,7 +61,7 @@ class Instance(ABC):
|
||||
|
||||
By default, returns [0.0].
|
||||
"""
|
||||
return [0.0]
|
||||
return np.zeros(1)
|
||||
|
||||
def get_variable_features(self) -> Dict[str, List[float]]:
|
||||
"""
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
from typing import Any, Optional, List, Dict, TYPE_CHECKING
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Hdf5Sample, Sample
|
||||
@@ -30,7 +31,7 @@ class FileInstance(Instance):
|
||||
return self.instance.to_model()
|
||||
|
||||
@overrides
|
||||
def get_instance_features(self) -> List[float]:
|
||||
def get_instance_features(self) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_instance_features()
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import os
|
||||
import pickle
|
||||
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.features.sample import Sample
|
||||
@@ -42,7 +43,7 @@ class PickleGzInstance(Instance):
|
||||
return self.instance.to_model()
|
||||
|
||||
@overrides
|
||||
def get_instance_features(self) -> List[float]:
|
||||
def get_instance_features(self) -> np.ndarray:
|
||||
assert self.instance is not None
|
||||
return self.instance.get_instance_features()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user