Re-add sample.{get,put}_bytes

master
Alinson S. Xavier 4 years ago
parent 256d3d094f
commit 5b3a56f053

@ -70,7 +70,9 @@ class Sample(ABC):
assert False, f"scalar expected; found instead: {value} ({value.__class__})"
def _assert_is_array(self, value: np.ndarray) -> None:
assert isinstance(value, np.ndarray)
assert isinstance(
value, np.ndarray
), f"np.ndarray expected; found instead: {value.__class__}"
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
def _assert_is_sparse(self, value: Any) -> None:
@ -205,3 +207,18 @@ class Hdf5Sample(Sample):
assert col is not None
assert data is not None
return coo_matrix((data, (row, col)))
def get_bytes(self, key: str) -> Optional[Bytes]:
if key not in self.file:
return None
ds = self.file[key]
assert (
len(ds.shape) == 1
), f"1-dimensional array expected; found shape {ds.shape}"
return ds[()].tobytes()
def put_bytes(self, key: str, value: Bytes) -> None:
assert isinstance(
value, (bytes, bytearray)
), f"bytes expected; found: {value.__class__}" # type: ignore
self.put_array(key, np.frombuffer(value, dtype="uint8"))

@ -111,14 +111,16 @@ class FileInstance(Instance):
def load(self) -> None:
if self.instance is not None:
return
self.instance = pickle.loads(self.h5.get_array("pickled").tobytes())
pkl = self.h5.get_bytes("pickled")
assert pkl is not None
self.instance = pickle.loads(pkl)
assert isinstance(self.instance, Instance)
@classmethod
def save(cls, instance: Instance, filename: str) -> None:
h5 = Hdf5Sample(filename, mode="w")
instance_pkl = np.frombuffer(pickle.dumps(instance), dtype=np.int8)
h5.put_array("pickled", instance_pkl)
instance_pkl = pickle.dumps(instance)
h5.put_bytes("pickled", instance_pkl)
@overrides
def create_sample(self) -> Sample:

@ -275,7 +275,7 @@ def _equals_preprocess(obj: Any) -> Any:
return np.round(obj, decimals=6).tolist()
else:
return obj.tolist()
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes)):
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes, bytearray)):
return obj
elif isinstance(obj, float):
return round(obj, 6)

@ -11,7 +11,7 @@ import numpy as np
import gurobipy as gp
from miplearn.features.extractor import FeaturesExtractor
from miplearn.features.sample import Hdf5Sample
from miplearn.features.sample import Hdf5Sample, MemorySample
from miplearn.instance.base import Instance
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.internal import Variables, Constraints

Loading…
Cancel
Save