Sample: Allow numpy arrays

master
Alinson S. Xavier 4 years ago
parent 6fd839351c
commit fc55a077f2
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -19,6 +19,7 @@ Vector = Union[
List[int], List[int],
List[float], List[float],
List[Optional[str]], List[Optional[str]],
np.ndarray,
] ]
VectorList = Union[ VectorList = Union[
List[List[bool]], List[List[bool]],
@ -86,12 +87,16 @@ class Sample(ABC):
assert False, f"Scalar expected; found instead: {value}" assert False, f"Scalar expected; found instead: {value}"
def _assert_is_vector(self, value: Any) -> None: def _assert_is_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}" assert isinstance(
value, (list, np.ndarray)
), f"List or numpy array expected; found instead: {value}"
for v in value: for v in value:
self._assert_is_scalar(v) self._assert_is_scalar(v)
def _assert_is_vector_list(self, value: Any) -> None: def _assert_is_vector_list(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}" assert isinstance(
value, (list, np.ndarray)
), f"List or numpy array expected; found instead: {value}"
for v in value: for v in value:
if v is None: if v is None:
continue continue

@ -3,8 +3,10 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
import numpy as np
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop
from miplearn.solvers.tests import assert_equals
def test_memory_sample() -> None: def test_memory_sample() -> None:
@ -28,6 +30,7 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_vector(sample, [True, True, False]) _assert_roundtrip_vector(sample, [True, True, False])
_assert_roundtrip_vector(sample, [1, 2, 3]) _assert_roundtrip_vector(sample, [1, 2, 3])
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0]) _assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
_assert_roundtrip_vector(sample, np.array([1.0, 2.0, 3.0]), check_type=False)
# VectorList # VectorList
_assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None]) _assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
@ -66,11 +69,14 @@ def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
_assert_same_type(actual, expected) _assert_same_type(actual, expected)
def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None: def _assert_roundtrip_vector(
sample: Sample, expected: Any, check_type: bool = True
) -> None:
sample.put_vector("key", expected) sample.put_vector("key", expected)
actual = sample.get_vector("key") actual = sample.get_vector("key")
assert actual == expected assert_equals(actual, expected)
assert actual is not None assert actual is not None
if check_type:
_assert_same_type(actual[0], expected[0]) _assert_same_type(actual[0], expected[0])

Loading…
Cancel
Save