Sample: Allow numpy arrays

master
Alinson S. Xavier 4 years ago
parent 6fd839351c
commit fc55a077f2
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -19,6 +19,7 @@ Vector = Union[
List[int],
List[float],
List[Optional[str]],
np.ndarray,
]
VectorList = Union[
List[List[bool]],
@ -86,12 +87,16 @@ class Sample(ABC):
assert False, f"Scalar expected; found instead: {value}"
def _assert_is_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
assert isinstance(
value, (list, np.ndarray)
), f"List or numpy array expected; found instead: {value}"
for v in value:
self._assert_is_scalar(v)
def _assert_is_vector_list(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
assert isinstance(
value, (list, np.ndarray)
), f"List or numpy array expected; found instead: {value}"
for v in value:
if v is None:
continue

@ -3,8 +3,10 @@
# Released under the modified BSD license. See COPYING.md for more details.
from tempfile import NamedTemporaryFile
from typing import Any
import numpy as np
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop
from miplearn.solvers.tests import assert_equals
def test_memory_sample() -> None:
@ -28,6 +30,7 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_vector(sample, [True, True, False])
_assert_roundtrip_vector(sample, [1, 2, 3])
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
_assert_roundtrip_vector(sample, np.array([1.0, 2.0, 3.0]), check_type=False)
# VectorList
_assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
@ -66,12 +69,15 @@ def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
_assert_same_type(actual, expected)
def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None:
def _assert_roundtrip_vector(
sample: Sample, expected: Any, check_type: bool = True
) -> None:
sample.put_vector("key", expected)
actual = sample.get_vector("key")
assert actual == expected
assert_equals(actual, expected)
assert actual is not None
_assert_same_type(actual[0], expected[0])
if check_type:
_assert_same_type(actual[0], expected[0])
def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None:

Loading…
Cancel
Save