mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Sample: handle None in vectors
This commit is contained in:
@@ -189,7 +189,9 @@ class Hdf5Sample(Sample):
|
||||
assert len(ds.shape) == 1
|
||||
print(ds.dtype)
|
||||
if h5py.check_string_dtype(ds.dtype):
|
||||
return ds.asstr()[:].tolist()
|
||||
result = ds.asstr()[:].tolist()
|
||||
result = [r if len(r) > 0 else None for r in result]
|
||||
return result
|
||||
else:
|
||||
return ds[:].tolist()
|
||||
|
||||
@@ -218,7 +220,8 @@ class Hdf5Sample(Sample):
|
||||
if value is None:
|
||||
return
|
||||
self._assert_is_vector(value)
|
||||
self._put(key, value)
|
||||
modified = [v if v is not None else "" for v in value]
|
||||
self._put(key, modified)
|
||||
|
||||
@overrides
|
||||
def put_vector_list(self, key: str, value: VectorList) -> None:
|
||||
|
||||
@@ -24,7 +24,7 @@ def _test_sample(sample: Sample) -> None:
|
||||
_assert_roundtrip_scalar(sample, 1.0)
|
||||
|
||||
# Vector
|
||||
_assert_roundtrip_vector(sample, ["A", "BB", "CCC", "こんにちは"])
|
||||
_assert_roundtrip_vector(sample, ["A", "BB", "CCC", "こんにちは", None])
|
||||
_assert_roundtrip_vector(sample, [True, True, False])
|
||||
_assert_roundtrip_vector(sample, [1, 2, 3])
|
||||
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
|
||||
|
||||
Reference in New Issue
Block a user