Implement sample.{get,put}_bytes

master
Alinson S. Xavier 4 years ago
parent 962707e8b7
commit 284ba15db6
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -35,6 +35,14 @@ VectorList = Union[
class Sample(ABC): class Sample(ABC):
"""Abstract dictionary-like class that stores training data.""" """Abstract dictionary-like class that stores training data."""
@abstractmethod
def get_bytes(self, key: str) -> Optional[bytes]:
    """Return the bytes value associated with `key`.

    Concrete subclasses supply the storage-specific lookup. The return type
    allows `None`; when a subclass returns it is up to that subclass.
    """
@abstractmethod
def put_bytes(self, key: str, value: bytes) -> None:
    """Store the bytes `value` under `key`.

    Concrete subclasses supply the storage-specific write.
    """
@abstractmethod @abstractmethod
def get_scalar(self, key: str) -> Optional[Any]: def get_scalar(self, key: str) -> Optional[Any]:
pass pass
@ -101,6 +109,10 @@ class MemorySample(Sample):
data = {} data = {}
self._data: Dict[str, Any] = data self._data: Dict[str, Any] = data
@overrides
def get_bytes(self, key: str) -> Optional[bytes]:
    """Look up `key` through the `_get` helper and return the result unchanged."""
    stored = self._get(key)
    return stored
@overrides @overrides
def get_scalar(self, key: str) -> Optional[Any]: def get_scalar(self, key: str) -> Optional[Any]:
return self._get(key) return self._get(key)
@ -113,6 +125,11 @@ class MemorySample(Sample):
def get_vector_list(self, key: str) -> Optional[Any]: def get_vector_list(self, key: str) -> Optional[Any]:
return self._get(key) return self._get(key)
@overrides
def put_bytes(self, key: str, value: bytes) -> None:
    """Store `value` under `key` after asserting that it is a `bytes` object."""
    # Other buffer types (e.g. bytearray) do not pass this check.
    is_bytes = isinstance(value, bytes)
    assert is_bytes
    self._put(key, value)
@overrides @overrides
def put_scalar(self, key: str, value: Scalar) -> None: def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_is_scalar(value) self._assert_is_scalar(value)
@ -151,6 +168,12 @@ class Hdf5Sample(Sample):
def __init__(self, filename: str) -> None: def __init__(self, filename: str) -> None:
self.file = h5py.File(filename, "r+") self.file = h5py.File(filename, "r+")
@overrides
def get_bytes(self, key: str) -> Optional[bytes]:
    """Read the dataset stored under `key` and return its contents as bytes.

    The dataset must be one-dimensional; this is enforced with an assertion.
    """
    # NOTE(review): indexing the file with an absent key presumably raises
    # instead of returning None, despite the Optional return type -- confirm.
    dataset = self.file[key]
    n_dims = len(dataset.shape)
    assert n_dims == 1
    # Index with the empty tuple to fetch the dataset's contents, then
    # serialize them with `tobytes()`.
    contents = dataset[()]
    return contents.tobytes()
@overrides @overrides
def get_scalar(self, key: str) -> Optional[Any]: def get_scalar(self, key: str) -> Optional[Any]:
ds = self.file[key] ds = self.file[key]
@ -180,6 +203,11 @@ class Hdf5Sample(Sample):
padded = ds[:].tolist() padded = ds[:].tolist()
return _crop(padded, lens) return _crop(padded, lens)
@overrides
def put_bytes(self, key: str, value: bytes) -> None:
    """Store `value` under `key` as an array of unsigned 8-bit integers."""
    assert isinstance(value, bytes)
    # Reinterpret the byte string as a uint8 array before handing it to `_put`.
    as_uint8 = np.frombuffer(value, dtype="uint8")
    self._put(key, as_uint8)
@overrides @overrides
def put_scalar(self, key: str, value: Any) -> None: def put_scalar(self, key: str, value: Any) -> None:
self._assert_is_scalar(value) self._assert_is_scalar(value)

@ -35,6 +35,17 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]]) _assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]])
_assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]]) _assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]])
# Bytes
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None:
sample.put_bytes("key", expected)
actual = sample.get_bytes("key")
assert actual == expected
assert actual is not None
_assert_same_type(actual, expected)
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
sample.put_scalar("key", expected) sample.put_scalar("key", expected)

Loading…
Cancel
Save