diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 39f3261..71d76e3 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -70,7 +70,9 @@ class Sample(ABC): assert False, f"scalar expected; found instead: {value} ({value.__class__})" def _assert_is_array(self, value: np.ndarray) -> None: - assert isinstance(value, np.ndarray) + assert isinstance( + value, np.ndarray + ), f"np.ndarray expected; found instead: {value.__class__}" assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}" def _assert_is_sparse(self, value: Any) -> None: @@ -205,3 +207,18 @@ class Hdf5Sample(Sample): assert col is not None assert data is not None return coo_matrix((data, (row, col))) + + def get_bytes(self, key: str) -> Optional[Bytes]: + if key not in self.file: + return None + ds = self.file[key] + assert ( + len(ds.shape) == 1 + ), f"1-dimensional array expected; found shape {ds.shape}" + return ds[()].tobytes() + + def put_bytes(self, key: str, value: Bytes) -> None: + assert isinstance( + value, (bytes, bytearray) + ), f"bytes expected; found: {value.__class__}" # type: ignore + self.put_array(key, np.frombuffer(value, dtype="uint8")) diff --git a/miplearn/instance/file.py b/miplearn/instance/file.py index a08d6b2..46e7609 100644 --- a/miplearn/instance/file.py +++ b/miplearn/instance/file.py @@ -111,14 +111,16 @@ class FileInstance(Instance): def load(self) -> None: if self.instance is not None: return - self.instance = pickle.loads(self.h5.get_array("pickled").tobytes()) + pkl = self.h5.get_bytes("pickled") + assert pkl is not None + self.instance = pickle.loads(pkl) assert isinstance(self.instance, Instance) @classmethod def save(cls, instance: Instance, filename: str) -> None: h5 = Hdf5Sample(filename, mode="w") - instance_pkl = np.frombuffer(pickle.dumps(instance), dtype=np.int8) - h5.put_array("pickled", instance_pkl) + instance_pkl = pickle.dumps(instance) + h5.put_bytes("pickled", instance_pkl) @overrides def create_sample(self) -> Sample: diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index 01ff2c2..f4abe83 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -275,7 +275,7 @@ def _equals_preprocess(obj: Any) -> Any: return np.round(obj, decimals=6).tolist() else: return obj.tolist() - elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes)): + elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes, bytearray)): return obj elif isinstance(obj, float): return round(obj, 6) diff --git a/tests/features/test_extractor.py b/tests/features/test_extractor.py index 19b8767..858d976 100644 --- a/tests/features/test_extractor.py +++ b/tests/features/test_extractor.py @@ -11,7 +11,7 @@ import numpy as np import gurobipy as gp from miplearn.features.extractor import FeaturesExtractor -from miplearn.features.sample import Hdf5Sample +from miplearn.features.sample import Hdf5Sample, MemorySample from miplearn.instance.base import Instance from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.internal import Variables, Constraints