diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 078c847..5f6f85f 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -19,6 +19,7 @@ Vector = Union[ List[int], List[float], List[Optional[str]], + np.ndarray, ] VectorList = Union[ List[List[bool]], @@ -86,12 +87,16 @@ class Sample(ABC): assert False, f"Scalar expected; found instead: {value}" def _assert_is_vector(self, value: Any) -> None: - assert isinstance(value, list), f"List expected; found instead: {value}" + assert isinstance( + value, (list, np.ndarray) + ), f"List or numpy array expected; found instead: {value}" for v in value: self._assert_is_scalar(v) def _assert_is_vector_list(self, value: Any) -> None: - assert isinstance(value, list), f"List expected; found instead: {value}" + assert isinstance( + value, (list, np.ndarray) + ), f"List or numpy array expected; found instead: {value}" for v in value: if v is None: continue diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index e07355f..3051470 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -3,8 +3,10 @@ # Released under the modified BSD license. See COPYING.md for more details. from tempfile import NamedTemporaryFile from typing import Any +import numpy as np from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop +from miplearn.solvers.tests import assert_equals def test_memory_sample() -> None: @@ -28,6 +30,7 @@ def _test_sample(sample: Sample) -> None: _assert_roundtrip_vector(sample, [True, True, False]) _assert_roundtrip_vector(sample, [1, 2, 3]) _assert_roundtrip_vector(sample, [1.0, 2.0, 3.0]) + _assert_roundtrip_vector(sample, np.array([1.0, 2.0, 3.0]), check_type=False) # VectorList _assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None]) @@ -66,12 +69,15 @@ def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: _assert_same_type(actual, expected) -def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None: +def _assert_roundtrip_vector( + sample: Sample, expected: Any, check_type: bool = True +) -> None: sample.put_vector("key", expected) actual = sample.get_vector("key") - assert actual == expected + assert_equals(actual, expected) assert actual is not None - _assert_same_type(actual[0], expected[0]) + if check_type: + _assert_same_type(actual[0], expected[0]) def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None: