Re-add sample.{get,put}_bytes

master
Alinson S. Xavier 4 years ago
parent 256d3d094f
commit 5b3a56f053

@ -70,7 +70,9 @@ class Sample(ABC):
assert False, f"scalar expected; found instead: {value} ({value.__class__})" assert False, f"scalar expected; found instead: {value} ({value.__class__})"
def _assert_is_array(self, value: np.ndarray) -> None: def _assert_is_array(self, value: np.ndarray) -> None:
assert isinstance(value, np.ndarray) assert isinstance(
value, np.ndarray
), f"np.ndarray expected; found instead: {value.__class__}"
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}" assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
def _assert_is_sparse(self, value: Any) -> None: def _assert_is_sparse(self, value: Any) -> None:
@ -205,3 +207,18 @@ class Hdf5Sample(Sample):
assert col is not None assert col is not None
assert data is not None assert data is not None
return coo_matrix((data, (row, col))) return coo_matrix((data, (row, col)))
def get_bytes(self, key: str) -> Optional[Bytes]:
if key not in self.file:
return None
ds = self.file[key]
assert (
len(ds.shape) == 1
), f"1-dimensional array expected; found shape {ds.shape}"
return ds[()].tobytes()
def put_bytes(self, key: str, value: Bytes) -> None:
assert isinstance(
value, (bytes, bytearray)
), f"bytes expected; found: {value.__class__}" # type: ignore
self.put_array(key, np.frombuffer(value, dtype="uint8"))

@ -111,14 +111,16 @@ class FileInstance(Instance):
def load(self) -> None: def load(self) -> None:
if self.instance is not None: if self.instance is not None:
return return
self.instance = pickle.loads(self.h5.get_array("pickled").tobytes()) pkl = self.h5.get_bytes("pickled")
assert pkl is not None
self.instance = pickle.loads(pkl)
assert isinstance(self.instance, Instance) assert isinstance(self.instance, Instance)
@classmethod @classmethod
def save(cls, instance: Instance, filename: str) -> None: def save(cls, instance: Instance, filename: str) -> None:
h5 = Hdf5Sample(filename, mode="w") h5 = Hdf5Sample(filename, mode="w")
instance_pkl = np.frombuffer(pickle.dumps(instance), dtype=np.int8) instance_pkl = pickle.dumps(instance)
h5.put_array("pickled", instance_pkl) h5.put_bytes("pickled", instance_pkl)
@overrides @overrides
def create_sample(self) -> Sample: def create_sample(self) -> Sample:

@ -275,7 +275,7 @@ def _equals_preprocess(obj: Any) -> Any:
return np.round(obj, decimals=6).tolist() return np.round(obj, decimals=6).tolist()
else: else:
return obj.tolist() return obj.tolist()
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes)): elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes, bytearray)):
return obj return obj
elif isinstance(obj, float): elif isinstance(obj, float):
return round(obj, 6) return round(obj, 6)

@ -11,7 +11,7 @@ import numpy as np
import gurobipy as gp import gurobipy as gp
from miplearn.features.extractor import FeaturesExtractor from miplearn.features.extractor import FeaturesExtractor
from miplearn.features.sample import Hdf5Sample from miplearn.features.sample import Hdf5Sample, MemorySample
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.internal import Variables, Constraints from miplearn.solvers.internal import Variables, Constraints

Loading…
Cancel
Save