diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 2c0c8f7..5503343 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -11,6 +11,7 @@ import numpy as np from h5py import Dataset from overrides import overrides +Bytes = Union[bytes, bytearray] Scalar = Union[None, bool, str, int, float] Vector = Union[ None, @@ -37,11 +38,11 @@ class Sample(ABC): """Abstract dictionary-like class that stores training data.""" @abstractmethod - def get_bytes(self, key: str) -> Optional[bytes]: + def get_bytes(self, key: str) -> Optional[Bytes]: pass @abstractmethod - def put_bytes(self, key: str, value: bytes) -> None: + def put_bytes(self, key: str, value: Bytes) -> None: pass @abstractmethod @@ -115,7 +116,7 @@ class MemorySample(Sample): self._data: Dict[str, Any] = data @overrides - def get_bytes(self, key: str) -> Optional[bytes]: + def get_bytes(self, key: str) -> Optional[Bytes]: return self._get(key) @overrides @@ -131,8 +132,10 @@ class MemorySample(Sample): return self._get(key) @overrides - def put_bytes(self, key: str, value: bytes) -> None: - assert isinstance(value, bytes), f"bytes expected; found: {value}" + def put_bytes(self, key: str, value: Bytes) -> None: + assert isinstance( + value, (bytes, bytearray) + ), f"bytes expected; found: {value}" # type: ignore self._put(key, value) @overrides @@ -176,7 +179,7 @@ class Hdf5Sample(Sample): self.file = h5py.File(filename, mode) @overrides - def get_bytes(self, key: str) -> Optional[bytes]: + def get_bytes(self, key: str) -> Optional[Bytes]: if key not in self.file: return None ds = self.file[key] @@ -226,8 +229,10 @@ class Hdf5Sample(Sample): return _crop(padded, lens) @overrides - def put_bytes(self, key: str, value: bytes) -> None: - assert isinstance(value, bytes), f"bytes expected; found: {value}" + def put_bytes(self, key: str, value: Bytes) -> None: + assert isinstance( + value, (bytes, bytearray) + ), f"bytes expected; found: {value}" # type: ignore self._put(key, np.frombuffer(value, dtype="uint8")) @overrides diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index 3051470..32b80bd 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -41,6 +41,11 @@ def _test_sample(sample: Sample) -> None: # Bytes _assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05") + _assert_roundtrip_bytes( + sample, + bytearray(b"\x00\x01\x02\x03\x04\x05"), + check_type=False, + ) # Querying unknown keys should return None assert sample.get_scalar("unknown-key") is None @@ -53,12 +58,15 @@ def _test_sample(sample: Sample) -> None: sample.put_vector("key", None) -def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None: +def _assert_roundtrip_bytes( + sample: Sample, expected: Any, check_type: bool = False +) -> None: sample.put_bytes("key", expected) actual = sample.get_bytes("key") assert actual == expected assert actual is not None - _assert_same_type(actual, expected) + if check_type: + _assert_same_type(actual, expected) def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: