diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 6407275..aa644e6 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -170,12 +170,16 @@ class Hdf5Sample(Sample): @overrides def get_bytes(self, key: str) -> Optional[bytes]: + if key not in self.file: + return None ds = self.file[key] assert len(ds.shape) == 1 return ds[()].tobytes() @overrides def get_scalar(self, key: str) -> Optional[Any]: + if key not in self.file: + return None ds = self.file[key] assert len(ds.shape) == 0 if h5py.check_string_dtype(ds.dtype): @@ -185,6 +189,8 @@ class Hdf5Sample(Sample): @overrides def get_vector(self, key: str) -> Optional[Any]: + if key not in self.file: + return None ds = self.file[key] assert len(ds.shape) == 1 print(ds.dtype) @@ -197,6 +203,8 @@ class Hdf5Sample(Sample): @overrides def get_vector_list(self, key: str) -> Optional[Any]: + if key not in self.file: + return None ds = self.file[key] lens = ds.attrs["lengths"] if h5py.check_string_dtype(ds.dtype): diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index 0b57ef9..41452c2 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -39,6 +39,11 @@ def _test_sample(sample: Sample) -> None: # Bytes _assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05") + assert sample.get_scalar("unknown-key") is None + assert sample.get_vector("unknown-key") is None + assert sample.get_vector_list("unknown-key") is None + assert sample.get_bytes("unknown-key") is None + def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None: sample.put_bytes("key", expected)