diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 40b15af..e119500 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -35,6 +35,14 @@ VectorList = Union[ class Sample(ABC): """Abstract dictionary-like class that stores training data.""" + @abstractmethod + def get_bytes(self, key: str) -> Optional[bytes]: + pass + + @abstractmethod + def put_bytes(self, key: str, value: bytes) -> None: + pass + @abstractmethod def get_scalar(self, key: str) -> Optional[Any]: pass @@ -101,6 +109,10 @@ class MemorySample(Sample): data = {} self._data: Dict[str, Any] = data + @overrides + def get_bytes(self, key: str) -> Optional[bytes]: + return self._get(key) + @overrides def get_scalar(self, key: str) -> Optional[Any]: return self._get(key) @@ -113,6 +125,11 @@ class MemorySample(Sample): def get_vector_list(self, key: str) -> Optional[Any]: return self._get(key) + @overrides + def put_bytes(self, key: str, value: bytes) -> None: + assert isinstance(value, bytes) + self._put(key, value) + @overrides def put_scalar(self, key: str, value: Scalar) -> None: self._assert_is_scalar(value) @@ -151,6 +168,12 @@ class Hdf5Sample(Sample): def __init__(self, filename: str) -> None: self.file = h5py.File(filename, "r+") + @overrides + def get_bytes(self, key: str) -> Optional[bytes]: + ds = self.file[key] + assert len(ds.shape) == 1 + return ds[()].tobytes() + @overrides def get_scalar(self, key: str) -> Optional[Any]: ds = self.file[key] @@ -180,6 +203,11 @@ class Hdf5Sample(Sample): padded = ds[:].tolist() return _crop(padded, lens) + @overrides + def put_bytes(self, key: str, value: bytes) -> None: + assert isinstance(value, bytes) + self._put(key, np.frombuffer(value, dtype="uint8")) + @overrides def put_scalar(self, key: str, value: Any) -> None: self._assert_is_scalar(value) diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index 5462210..f18ef1a 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -35,6 +35,17 @@ def _test_sample(sample: Sample) -> None: _assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]]) _assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]]) + # Bytes + _assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05") + + +def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None: + sample.put_bytes("key", expected) + actual = sample.get_bytes("key") + assert actual == expected + assert actual is not None + _assert_same_type(actual, expected) + def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: sample.put_scalar("key", expected)