Re-add sample.{get,put}_bytes

This commit is contained in:
2021-08-11 06:24:10 -05:00
parent 256d3d094f
commit 5b3a56f053
4 changed files with 25 additions and 6 deletions

View File

@@ -70,7 +70,9 @@ class Sample(ABC):
assert False, f"scalar expected; found instead: {value} ({value.__class__})"
def _assert_is_array(self, value: np.ndarray) -> None:
assert isinstance(value, np.ndarray)
assert isinstance(
value, np.ndarray
), f"np.ndarray expected; found instead: {value.__class__}"
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
def _assert_is_sparse(self, value: Any) -> None:
@@ -205,3 +207,18 @@ class Hdf5Sample(Sample):
assert col is not None
assert data is not None
return coo_matrix((data, (row, col)))
def get_bytes(self, key: str) -> Optional[Bytes]:
if key not in self.file:
return None
ds = self.file[key]
assert (
len(ds.shape) == 1
), f"1-dimensional array expected; found shape {ds.shape}"
return ds[()].tobytes()
def put_bytes(self, key: str, value: Bytes) -> None:
assert isinstance(
value, (bytes, bytearray)
), f"bytes expected; found: {value.__class__}" # type: ignore
self.put_array(key, np.frombuffer(value, dtype="uint8"))

View File

@@ -111,14 +111,16 @@ class FileInstance(Instance):
def load(self) -> None:
if self.instance is not None:
return
self.instance = pickle.loads(self.h5.get_array("pickled").tobytes())
pkl = self.h5.get_bytes("pickled")
assert pkl is not None
self.instance = pickle.loads(pkl)
assert isinstance(self.instance, Instance)
@classmethod
def save(cls, instance: Instance, filename: str) -> None:
h5 = Hdf5Sample(filename, mode="w")
instance_pkl = np.frombuffer(pickle.dumps(instance), dtype=np.int8)
h5.put_array("pickled", instance_pkl)
instance_pkl = pickle.dumps(instance)
h5.put_bytes("pickled", instance_pkl)
@overrides
def create_sample(self) -> Sample:

View File

@@ -275,7 +275,7 @@ def _equals_preprocess(obj: Any) -> Any:
return np.round(obj, decimals=6).tolist()
else:
return obj.tolist()
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes)):
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes, bytearray)):
return obj
elif isinstance(obj, float):
return round(obj, 6)