mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 17:38:51 -06:00
Implement sample.{get,put}_bytes
This commit is contained in:
@@ -35,6 +35,14 @@ VectorList = Union[
|
|||||||
class Sample(ABC):
|
class Sample(ABC):
|
||||||
"""Abstract dictionary-like class that stores training data."""
|
"""Abstract dictionary-like class that stores training data."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_bytes(self, key: str) -> Optional[bytes]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def put_bytes(self, key: str, value: bytes) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_scalar(self, key: str) -> Optional[Any]:
|
def get_scalar(self, key: str) -> Optional[Any]:
|
||||||
pass
|
pass
|
||||||
@@ -101,6 +109,10 @@ class MemorySample(Sample):
|
|||||||
data = {}
|
data = {}
|
||||||
self._data: Dict[str, Any] = data
|
self._data: Dict[str, Any] = data
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def get_bytes(self, key: str) -> Optional[bytes]:
|
||||||
|
return self._get(key)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_scalar(self, key: str) -> Optional[Any]:
|
def get_scalar(self, key: str) -> Optional[Any]:
|
||||||
return self._get(key)
|
return self._get(key)
|
||||||
@@ -113,6 +125,11 @@ class MemorySample(Sample):
|
|||||||
def get_vector_list(self, key: str) -> Optional[Any]:
|
def get_vector_list(self, key: str) -> Optional[Any]:
|
||||||
return self._get(key)
|
return self._get(key)
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def put_bytes(self, key: str, value: bytes) -> None:
|
||||||
|
assert isinstance(value, bytes)
|
||||||
|
self._put(key, value)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_scalar(self, key: str, value: Scalar) -> None:
|
def put_scalar(self, key: str, value: Scalar) -> None:
|
||||||
self._assert_is_scalar(value)
|
self._assert_is_scalar(value)
|
||||||
@@ -151,6 +168,12 @@ class Hdf5Sample(Sample):
|
|||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str) -> None:
|
||||||
self.file = h5py.File(filename, "r+")
|
self.file = h5py.File(filename, "r+")
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def get_bytes(self, key: str) -> Optional[bytes]:
|
||||||
|
ds = self.file[key]
|
||||||
|
assert len(ds.shape) == 1
|
||||||
|
return ds[()].tobytes()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_scalar(self, key: str) -> Optional[Any]:
|
def get_scalar(self, key: str) -> Optional[Any]:
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
@@ -180,6 +203,11 @@ class Hdf5Sample(Sample):
|
|||||||
padded = ds[:].tolist()
|
padded = ds[:].tolist()
|
||||||
return _crop(padded, lens)
|
return _crop(padded, lens)
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def put_bytes(self, key: str, value: bytes) -> None:
|
||||||
|
assert isinstance(value, bytes)
|
||||||
|
self._put(key, np.frombuffer(value, dtype="uint8"))
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_scalar(self, key: str, value: Any) -> None:
|
def put_scalar(self, key: str, value: Any) -> None:
|
||||||
self._assert_is_scalar(value)
|
self._assert_is_scalar(value)
|
||||||
|
|||||||
@@ -35,6 +35,17 @@ def _test_sample(sample: Sample) -> None:
|
|||||||
_assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]])
|
_assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]])
|
||||||
_assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]])
|
_assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]])
|
||||||
|
|
||||||
|
# Bytes
|
||||||
|
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None:
|
||||||
|
sample.put_bytes("key", expected)
|
||||||
|
actual = sample.get_bytes("key")
|
||||||
|
assert actual == expected
|
||||||
|
assert actual is not None
|
||||||
|
_assert_same_type(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
|
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
|
||||||
sample.put_scalar("key", expected)
|
sample.put_scalar("key", expected)
|
||||||
|
|||||||
Reference in New Issue
Block a user