Implement sample.{get,put}_bytes

This commit is contained in:
2021-07-27 10:01:32 -05:00
parent 962707e8b7
commit 284ba15db6
2 changed files with 39 additions and 0 deletions

View File

@@ -35,6 +35,14 @@ VectorList = Union[
class Sample(ABC):
"""Abstract dictionary-like class that stores training data."""
@abstractmethod
def get_bytes(self, key: str) -> Optional[bytes]:
pass
@abstractmethod
def put_bytes(self, key: str, value: bytes) -> None:
pass
@abstractmethod
def get_scalar(self, key: str) -> Optional[Any]:
pass
@@ -101,6 +109,10 @@ class MemorySample(Sample):
data = {}
self._data: Dict[str, Any] = data
@overrides
def get_bytes(self, key: str) -> Optional[bytes]:
return self._get(key)
@overrides
def get_scalar(self, key: str) -> Optional[Any]:
return self._get(key)
@@ -113,6 +125,11 @@ class MemorySample(Sample):
def get_vector_list(self, key: str) -> Optional[Any]:
return self._get(key)
@overrides
def put_bytes(self, key: str, value: bytes) -> None:
assert isinstance(value, bytes)
self._put(key, value)
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_is_scalar(value)
@@ -151,6 +168,12 @@ class Hdf5Sample(Sample):
def __init__(self, filename: str) -> None:
self.file = h5py.File(filename, "r+")
@overrides
def get_bytes(self, key: str) -> Optional[bytes]:
ds = self.file[key]
assert len(ds.shape) == 1
return ds[()].tobytes()
@overrides
def get_scalar(self, key: str) -> Optional[Any]:
ds = self.file[key]
@@ -180,6 +203,11 @@ class Hdf5Sample(Sample):
padded = ds[:].tolist()
return _crop(padded, lens)
@overrides
def put_bytes(self, key: str, value: bytes) -> None:
assert isinstance(value, bytes)
self._put(key, np.frombuffer(value, dtype="uint8"))
@overrides
def put_scalar(self, key: str, value: Any) -> None:
self._assert_is_scalar(value)