Make Hdf5Sample work with bytearray

master
Alinson S. Xavier 4 years ago
parent a69cbed7b7
commit 7d5ec1344a
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -11,6 +11,7 @@ import numpy as np
from h5py import Dataset from h5py import Dataset
from overrides import overrides from overrides import overrides
Bytes = Union[bytes, bytearray]
Scalar = Union[None, bool, str, int, float] Scalar = Union[None, bool, str, int, float]
Vector = Union[ Vector = Union[
None, None,
@ -37,11 +38,11 @@ class Sample(ABC):
"""Abstract dictionary-like class that stores training data.""" """Abstract dictionary-like class that stores training data."""
@abstractmethod @abstractmethod
def get_bytes(self, key: str) -> Optional[bytes]: def get_bytes(self, key: str) -> Optional[Bytes]:
pass pass
@abstractmethod @abstractmethod
def put_bytes(self, key: str, value: bytes) -> None: def put_bytes(self, key: str, value: Bytes) -> None:
pass pass
@abstractmethod @abstractmethod
@ -115,7 +116,7 @@ class MemorySample(Sample):
self._data: Dict[str, Any] = data self._data: Dict[str, Any] = data
@overrides @overrides
def get_bytes(self, key: str) -> Optional[bytes]: def get_bytes(self, key: str) -> Optional[Bytes]:
return self._get(key) return self._get(key)
@overrides @overrides
@ -131,8 +132,10 @@ class MemorySample(Sample):
return self._get(key) return self._get(key)
@overrides @overrides
def put_bytes(self, key: str, value: bytes) -> None: def put_bytes(self, key: str, value: Bytes) -> None:
assert isinstance(value, bytes), f"bytes expected; found: {value}" assert isinstance(
value, (bytes, bytearray)
), f"bytes expected; found: {value}" # type: ignore
self._put(key, value) self._put(key, value)
@overrides @overrides
@ -176,7 +179,7 @@ class Hdf5Sample(Sample):
self.file = h5py.File(filename, mode) self.file = h5py.File(filename, mode)
@overrides @overrides
def get_bytes(self, key: str) -> Optional[bytes]: def get_bytes(self, key: str) -> Optional[Bytes]:
if key not in self.file: if key not in self.file:
return None return None
ds = self.file[key] ds = self.file[key]
@ -226,8 +229,10 @@ class Hdf5Sample(Sample):
return _crop(padded, lens) return _crop(padded, lens)
@overrides @overrides
def put_bytes(self, key: str, value: bytes) -> None: def put_bytes(self, key: str, value: Bytes) -> None:
assert isinstance(value, bytes), f"bytes expected; found: {value}" assert isinstance(
value, (bytes, bytearray)
), f"bytes expected; found: {value}" # type: ignore
self._put(key, np.frombuffer(value, dtype="uint8")) self._put(key, np.frombuffer(value, dtype="uint8"))
@overrides @overrides

@ -41,6 +41,11 @@ def _test_sample(sample: Sample) -> None:
# Bytes # Bytes
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05") _assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
_assert_roundtrip_bytes(
sample,
bytearray(b"\x00\x01\x02\x03\x04\x05"),
check_type=False,
)
# Querying unknown keys should return None # Querying unknown keys should return None
assert sample.get_scalar("unknown-key") is None assert sample.get_scalar("unknown-key") is None
@ -53,12 +58,15 @@ def _test_sample(sample: Sample) -> None:
sample.put_vector("key", None) sample.put_vector("key", None)
def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None: def _assert_roundtrip_bytes(
sample: Sample, expected: Any, check_type: bool = False
) -> None:
sample.put_bytes("key", expected) sample.put_bytes("key", expected)
actual = sample.get_bytes("key") actual = sample.get_bytes("key")
assert actual == expected assert actual == expected
assert actual is not None assert actual is not None
_assert_same_type(actual, expected) if check_type:
_assert_same_type(actual, expected)
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:

Loading…
Cancel
Save