diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index aa644e6..e75f42e 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -132,6 +132,8 @@ class MemorySample(Sample): @overrides def put_scalar(self, key: str, value: Scalar) -> None: + if value is None: + return self._assert_is_scalar(value) self._put(key, value) @@ -220,6 +222,8 @@ class Hdf5Sample(Sample): @overrides def put_scalar(self, key: str, value: Any) -> None: + if value is None: + return self._assert_is_scalar(value) self._put(key, value) diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index 41452c2..e07355f 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -39,11 +39,16 @@ def _test_sample(sample: Sample) -> None: # Bytes _assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05") + # Querying unknown keys should return None assert sample.get_scalar("unknown-key") is None assert sample.get_vector("unknown-key") is None assert sample.get_vector_list("unknown-key") is None assert sample.get_bytes("unknown-key") is None + # Putting None should not modify HDF5 file + sample.put_scalar("key", None) + sample.put_vector("key", None) + def _assert_roundtrip_bytes(sample: Sample, expected: Any) -> None: sample.put_bytes("key", expected)