mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Hdf5Sample: Return None for non-existing keys
This commit is contained in:
@@ -170,12 +170,16 @@ class Hdf5Sample(Sample):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_bytes(self, key: str) -> Optional[bytes]:
|
def get_bytes(self, key: str) -> Optional[bytes]:
|
||||||
|
if key not in self.file:
|
||||||
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
assert len(ds.shape) == 1
|
assert len(ds.shape) == 1
|
||||||
return ds[()].tobytes()
|
return ds[()].tobytes()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_scalar(self, key: str) -> Optional[Any]:
|
def get_scalar(self, key: str) -> Optional[Any]:
|
||||||
|
if key not in self.file:
|
||||||
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
assert len(ds.shape) == 0
|
assert len(ds.shape) == 0
|
||||||
if h5py.check_string_dtype(ds.dtype):
|
if h5py.check_string_dtype(ds.dtype):
|
||||||
@@ -185,6 +189,8 @@ class Hdf5Sample(Sample):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_vector(self, key: str) -> Optional[Any]:
|
def get_vector(self, key: str) -> Optional[Any]:
|
||||||
|
if key not in self.file:
|
||||||
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
assert len(ds.shape) == 1
|
assert len(ds.shape) == 1
|
||||||
print(ds.dtype)
|
print(ds.dtype)
|
||||||
@@ -197,6 +203,8 @@ class Hdf5Sample(Sample):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_vector_list(self, key: str) -> Optional[Any]:
|
def get_vector_list(self, key: str) -> Optional[Any]:
|
||||||
|
if key not in self.file:
|
||||||
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
lens = ds.attrs["lengths"]
|
lens = ds.attrs["lengths"]
|
||||||
if h5py.check_string_dtype(ds.dtype):
|
if h5py.check_string_dtype(ds.dtype):
|
||||||
|
|||||||
@@ -39,6 +39,11 @@ def _test_sample(sample: Sample) -> None:
|
|||||||
# Bytes
|
# Bytes
|
||||||
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
|
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
|
||||||
|
|
||||||
|
assert sample.get_scalar("unknown-key") is None
|
||||||
|
assert sample.get_vector("unknown-key") is None
|
||||||
|
assert sample.get_vector_list("unknown-key") is None
|
||||||
|
assert sample.get_bytes("unknown-key") is None
|
||||||
|
|
||||||
|
|
||||||
def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None:
|
def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None:
|
||||||
sample.put_bytes("key", expected)
|
sample.put_bytes("key", expected)
|
||||||
|
|||||||
Reference in New Issue
Block a user