diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 0dd6005..65ddab5 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -249,6 +249,7 @@ class Hdf5Sample(Sample): self._assert_is_vector(value) for v in value: + # Convert strings to bytes if isinstance(v, str): value = np.array( [u if u is not None else b"" for u in value], @@ -256,6 +257,11 @@ class Hdf5Sample(Sample): ) break + # Convert all floating point numbers to half-precision + if isinstance(v, float): + value = np.array(value, dtype=np.dtype("f2")) + break + self._put(key, value, compress=True) @overrides @@ -269,6 +275,8 @@ class Hdf5Sample(Sample): continue if isinstance(v[0], str): data = np.array(padded, dtype="S") + elif isinstance(v[0], float): + data = np.array(padded, dtype=np.dtype("f2")) elif isinstance(v[0], bool): data = np.array(padded, dtype=bool) else: