From 3da8d532a82f97c239d405d00f7624684c2159fe Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 27 Jul 2021 10:37:02 -0500 Subject: [PATCH] Sample: handle None in vectors --- miplearn/features/sample.py | 7 +++++-- tests/features/test_sample.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index e119500..bf07fc7 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -189,7 +189,9 @@ class Hdf5Sample(Sample): assert len(ds.shape) == 1 print(ds.dtype) if h5py.check_string_dtype(ds.dtype): - return ds.asstr()[:].tolist() + result = ds.asstr()[:].tolist() + result = [r if len(r) > 0 else None for r in result] + return result else: return ds[:].tolist() @@ -218,7 +220,8 @@ class Hdf5Sample(Sample): if value is None: return self._assert_is_vector(value) - self._put(key, value) + modified = [v if v is not None else "" for v in value] + self._put(key, modified) @overrides def put_vector_list(self, key: str, value: VectorList) -> None: diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index f18ef1a..7e4bee7 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -24,7 +24,7 @@ def _test_sample(sample: Sample) -> None: _assert_roundtrip_scalar(sample, 1.0) # Vector - _assert_roundtrip_vector(sample, ["A", "BB", "CCC", "こんにちは"]) + _assert_roundtrip_vector(sample, ["A", "BB", "CCC", "こんにちは", None]) _assert_roundtrip_vector(sample, [True, True, False]) _assert_roundtrip_vector(sample, [1, 2, 3]) _assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])