Make Hdf5Sample work with bytearray

This commit is contained in:
2021-07-28 09:06:15 -05:00
parent a69cbed7b7
commit 7d5ec1344a
2 changed files with 23 additions and 10 deletions

View File

@@ -11,6 +11,7 @@ import numpy as np
from h5py import Dataset
from overrides import overrides
Bytes = Union[bytes, bytearray]
Scalar = Union[None, bool, str, int, float]
Vector = Union[
None,
@@ -37,11 +38,11 @@ class Sample(ABC):
"""Abstract dictionary-like class that stores training data."""
@abstractmethod
def get_bytes(self, key: str) -> Optional[bytes]:
def get_bytes(self, key: str) -> Optional[Bytes]:
pass
@abstractmethod
def put_bytes(self, key: str, value: bytes) -> None:
def put_bytes(self, key: str, value: Bytes) -> None:
pass
@abstractmethod
@@ -115,7 +116,7 @@ class MemorySample(Sample):
self._data: Dict[str, Any] = data
@overrides
def get_bytes(self, key: str) -> Optional[bytes]:
def get_bytes(self, key: str) -> Optional[Bytes]:
return self._get(key)
@overrides
@@ -131,8 +132,10 @@ class MemorySample(Sample):
return self._get(key)
@overrides
def put_bytes(self, key: str, value: bytes) -> None:
assert isinstance(value, bytes), f"bytes expected; found: {value}"
def put_bytes(self, key: str, value: Bytes) -> None:
assert isinstance(
value, (bytes, bytearray)
), f"bytes expected; found: {value}" # type: ignore
self._put(key, value)
@overrides
@@ -176,7 +179,7 @@ class Hdf5Sample(Sample):
self.file = h5py.File(filename, mode)
@overrides
def get_bytes(self, key: str) -> Optional[bytes]:
def get_bytes(self, key: str) -> Optional[Bytes]:
if key not in self.file:
return None
ds = self.file[key]
@@ -226,8 +229,10 @@ class Hdf5Sample(Sample):
return _crop(padded, lens)
@overrides
def put_bytes(self, key: str, value: bytes) -> None:
assert isinstance(value, bytes), f"bytes expected; found: {value}"
def put_bytes(self, key: str, value: Bytes) -> None:
assert isinstance(
value, (bytes, bytearray)
), f"bytes expected; found: {value}" # type: ignore
self._put(key, np.frombuffer(value, dtype="uint8"))
@overrides