mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Make Hdf5Sample work with bytearray
This commit is contained in:
@@ -11,6 +11,7 @@ import numpy as np
|
|||||||
from h5py import Dataset
|
from h5py import Dataset
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
|
Bytes = Union[bytes, bytearray]
|
||||||
Scalar = Union[None, bool, str, int, float]
|
Scalar = Union[None, bool, str, int, float]
|
||||||
Vector = Union[
|
Vector = Union[
|
||||||
None,
|
None,
|
||||||
@@ -37,11 +38,11 @@ class Sample(ABC):
|
|||||||
"""Abstract dictionary-like class that stores training data."""
|
"""Abstract dictionary-like class that stores training data."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_bytes(self, key: str) -> Optional[bytes]:
|
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def put_bytes(self, key: str, value: bytes) -> None:
|
def put_bytes(self, key: str, value: Bytes) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@@ -115,7 +116,7 @@ class MemorySample(Sample):
|
|||||||
self._data: Dict[str, Any] = data
|
self._data: Dict[str, Any] = data
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_bytes(self, key: str) -> Optional[bytes]:
|
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||||
return self._get(key)
|
return self._get(key)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
@@ -131,8 +132,10 @@ class MemorySample(Sample):
|
|||||||
return self._get(key)
|
return self._get(key)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_bytes(self, key: str, value: bytes) -> None:
|
def put_bytes(self, key: str, value: Bytes) -> None:
|
||||||
assert isinstance(value, bytes), f"bytes expected; found: {value}"
|
assert isinstance(
|
||||||
|
value, (bytes, bytearray)
|
||||||
|
), f"bytes expected; found: {value}" # type: ignore
|
||||||
self._put(key, value)
|
self._put(key, value)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
@@ -176,7 +179,7 @@ class Hdf5Sample(Sample):
|
|||||||
self.file = h5py.File(filename, mode)
|
self.file = h5py.File(filename, mode)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_bytes(self, key: str) -> Optional[bytes]:
|
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||||
if key not in self.file:
|
if key not in self.file:
|
||||||
return None
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
@@ -226,8 +229,10 @@ class Hdf5Sample(Sample):
|
|||||||
return _crop(padded, lens)
|
return _crop(padded, lens)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_bytes(self, key: str, value: bytes) -> None:
|
def put_bytes(self, key: str, value: Bytes) -> None:
|
||||||
assert isinstance(value, bytes), f"bytes expected; found: {value}"
|
assert isinstance(
|
||||||
|
value, (bytes, bytearray)
|
||||||
|
), f"bytes expected; found: {value}" # type: ignore
|
||||||
self._put(key, np.frombuffer(value, dtype="uint8"))
|
self._put(key, np.frombuffer(value, dtype="uint8"))
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
|
|||||||
@@ -41,6 +41,11 @@ def _test_sample(sample: Sample) -> None:
|
|||||||
|
|
||||||
# Bytes
|
# Bytes
|
||||||
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
|
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
|
||||||
|
_assert_roundtrip_bytes(
|
||||||
|
sample,
|
||||||
|
bytearray(b"\x00\x01\x02\x03\x04\x05"),
|
||||||
|
check_type=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Querying unknown keys should return None
|
# Querying unknown keys should return None
|
||||||
assert sample.get_scalar("unknown-key") is None
|
assert sample.get_scalar("unknown-key") is None
|
||||||
@@ -53,12 +58,15 @@ def _test_sample(sample: Sample) -> None:
|
|||||||
sample.put_vector("key", None)
|
sample.put_vector("key", None)
|
||||||
|
|
||||||
|
|
||||||
def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None:
|
def _assert_roundtrip_bytes(
|
||||||
|
sample: Sample, expected: Any, check_type: bool = False
|
||||||
|
) -> None:
|
||||||
sample.put_bytes("key", expected)
|
sample.put_bytes("key", expected)
|
||||||
actual = sample.get_bytes("key")
|
actual = sample.get_bytes("key")
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
assert actual is not None
|
assert actual is not None
|
||||||
_assert_same_type(actual, expected)
|
if check_type:
|
||||||
|
_assert_same_type(actual, expected)
|
||||||
|
|
||||||
|
|
||||||
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
|
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user