diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 41c7f52..078c847 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -207,7 +207,7 @@ class Hdf5Sample(Sample): if key not in self.file: return None ds = self.file[key] - lens = ds.attrs["lengths"] + lens = self.get_vector(f"{key}_lengths") if h5py.check_string_dtype(ds.dtype): padded = ds.asstr()[:].tolist() else: @@ -238,6 +238,7 @@ class Hdf5Sample(Sample): def put_vector_list(self, key: str, value: VectorList) -> None: self._assert_is_vector_list(value) padded, lens = _pad(value) + self.put_vector(f"{key}_lengths", lens) data = None for v in value: if v is None or len(v) == 0: @@ -251,8 +252,7 @@ class Hdf5Sample(Sample): break if data is None: data = np.array(padded) - ds = self._put(key, data) - ds.attrs["lengths"] = lens + self._put(key, data) def _put(self, key: str, value: Any) -> Dataset: if key in self.file: