Make get_instance_features return np.ndarray

This commit is contained in:
2021-08-09 14:02:14 -05:00
parent 47d3011808
commit 56b39b6c9c
6 changed files with 29 additions and 16 deletions

View File

@@ -6,6 +6,8 @@ import logging
from abc import ABC, abstractmethod
from typing import Any, List, TYPE_CHECKING, Dict
import numpy as np
from miplearn.features.sample import Sample, MemorySample
logger = logging.getLogger(__name__)
@@ -37,7 +39,7 @@ class Instance(ABC):
"""
pass
def get_instance_features(self) -> List[float]:
def get_instance_features(self) -> np.ndarray:
"""
Returns a 1-dimensional array of (numerical) features describing the
entire instance.
@@ -59,7 +61,7 @@ class Instance(ABC):
By default, returns [0.0].
"""
return [0.0]
return np.zeros(1)
def get_variable_features(self) -> Dict[str, List[float]]:
"""

View File

@@ -6,6 +6,7 @@ import os
from typing import Any, Optional, List, Dict, TYPE_CHECKING
import pickle
import numpy as np
from overrides import overrides
from miplearn.features.sample import Hdf5Sample, Sample
@@ -30,7 +31,7 @@ class FileInstance(Instance):
return self.instance.to_model()
@overrides
def get_instance_features(self) -> List[float]:
def get_instance_features(self) -> np.ndarray:
assert self.instance is not None
return self.instance.get_instance_features()

View File

@@ -8,6 +8,7 @@ import os
import pickle
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
import numpy as np
from overrides import overrides
from miplearn.features.sample import Sample
@@ -42,7 +43,7 @@ class PickleGzInstance(Instance):
return self.instance.to_model()
@overrides
def get_instance_features(self) -> List[float]:
def get_instance_features(self) -> np.ndarray:
assert self.instance is not None
return self.instance.get_instance_features()