mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Sample: Allow numpy arrays
This commit is contained in:
@@ -19,6 +19,7 @@ Vector = Union[
|
|||||||
List[int],
|
List[int],
|
||||||
List[float],
|
List[float],
|
||||||
List[Optional[str]],
|
List[Optional[str]],
|
||||||
|
np.ndarray,
|
||||||
]
|
]
|
||||||
VectorList = Union[
|
VectorList = Union[
|
||||||
List[List[bool]],
|
List[List[bool]],
|
||||||
@@ -86,12 +87,16 @@ class Sample(ABC):
|
|||||||
assert False, f"Scalar expected; found instead: {value}"
|
assert False, f"Scalar expected; found instead: {value}"
|
||||||
|
|
||||||
def _assert_is_vector(self, value: Any) -> None:
|
def _assert_is_vector(self, value: Any) -> None:
|
||||||
assert isinstance(value, list), f"List expected; found instead: {value}"
|
assert isinstance(
|
||||||
|
value, (list, np.ndarray)
|
||||||
|
), f"List or numpy array expected; found instead: {value}"
|
||||||
for v in value:
|
for v in value:
|
||||||
self._assert_is_scalar(v)
|
self._assert_is_scalar(v)
|
||||||
|
|
||||||
def _assert_is_vector_list(self, value: Any) -> None:
|
def _assert_is_vector_list(self, value: Any) -> None:
|
||||||
assert isinstance(value, list), f"List expected; found instead: {value}"
|
assert isinstance(
|
||||||
|
value, (list, np.ndarray)
|
||||||
|
), f"List or numpy array expected; found instead: {value}"
|
||||||
for v in value:
|
for v in value:
|
||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop
|
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop
|
||||||
|
from miplearn.solvers.tests import assert_equals
|
||||||
|
|
||||||
|
|
||||||
def test_memory_sample() -> None:
|
def test_memory_sample() -> None:
|
||||||
@@ -28,6 +30,7 @@ def _test_sample(sample: Sample) -> None:
|
|||||||
_assert_roundtrip_vector(sample, [True, True, False])
|
_assert_roundtrip_vector(sample, [True, True, False])
|
||||||
_assert_roundtrip_vector(sample, [1, 2, 3])
|
_assert_roundtrip_vector(sample, [1, 2, 3])
|
||||||
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
|
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
|
||||||
|
_assert_roundtrip_vector(sample, np.array([1.0, 2.0, 3.0]), check_type=False)
|
||||||
|
|
||||||
# VectorList
|
# VectorList
|
||||||
_assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
|
_assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
|
||||||
@@ -66,12 +69,15 @@ def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
|
|||||||
_assert_same_type(actual, expected)
|
_assert_same_type(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None:
|
def _assert_roundtrip_vector(
|
||||||
|
sample: Sample, expected: Any, check_type: bool = True
|
||||||
|
) -> None:
|
||||||
sample.put_vector("key", expected)
|
sample.put_vector("key", expected)
|
||||||
actual = sample.get_vector("key")
|
actual = sample.get_vector("key")
|
||||||
assert actual == expected
|
assert_equals(actual, expected)
|
||||||
assert actual is not None
|
assert actual is not None
|
||||||
_assert_same_type(actual[0], expected[0])
|
if check_type:
|
||||||
|
_assert_same_type(actual[0], expected[0])
|
||||||
|
|
||||||
|
|
||||||
def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None:
|
def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user