mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Re-add sample.{get,put}_bytes
This commit is contained in:
@@ -70,7 +70,9 @@ class Sample(ABC):
|
||||
assert False, f"scalar expected; found instead: {value} ({value.__class__})"
|
||||
|
||||
def _assert_is_array(self, value: np.ndarray) -> None:
|
||||
assert isinstance(value, np.ndarray)
|
||||
assert isinstance(
|
||||
value, np.ndarray
|
||||
), f"np.ndarray expected; found instead: {value.__class__}"
|
||||
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
|
||||
|
||||
def _assert_is_sparse(self, value: Any) -> None:
|
||||
@@ -205,3 +207,18 @@ class Hdf5Sample(Sample):
|
||||
assert col is not None
|
||||
assert data is not None
|
||||
return coo_matrix((data, (row, col)))
|
||||
|
||||
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||
if key not in self.file:
|
||||
return None
|
||||
ds = self.file[key]
|
||||
assert (
|
||||
len(ds.shape) == 1
|
||||
), f"1-dimensional array expected; found shape {ds.shape}"
|
||||
return ds[()].tobytes()
|
||||
|
||||
def put_bytes(self, key: str, value: Bytes) -> None:
|
||||
assert isinstance(
|
||||
value, (bytes, bytearray)
|
||||
), f"bytes expected; found: {value.__class__}" # type: ignore
|
||||
self.put_array(key, np.frombuffer(value, dtype="uint8"))
|
||||
|
||||
@@ -111,14 +111,16 @@ class FileInstance(Instance):
|
||||
def load(self) -> None:
|
||||
if self.instance is not None:
|
||||
return
|
||||
self.instance = pickle.loads(self.h5.get_array("pickled").tobytes())
|
||||
pkl = self.h5.get_bytes("pickled")
|
||||
assert pkl is not None
|
||||
self.instance = pickle.loads(pkl)
|
||||
assert isinstance(self.instance, Instance)
|
||||
|
||||
@classmethod
|
||||
def save(cls, instance: Instance, filename: str) -> None:
|
||||
h5 = Hdf5Sample(filename, mode="w")
|
||||
instance_pkl = np.frombuffer(pickle.dumps(instance), dtype=np.int8)
|
||||
h5.put_array("pickled", instance_pkl)
|
||||
instance_pkl = pickle.dumps(instance)
|
||||
h5.put_bytes("pickled", instance_pkl)
|
||||
|
||||
@overrides
|
||||
def create_sample(self) -> Sample:
|
||||
|
||||
@@ -275,7 +275,7 @@ def _equals_preprocess(obj: Any) -> Any:
|
||||
return np.round(obj, decimals=6).tolist()
|
||||
else:
|
||||
return obj.tolist()
|
||||
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes)):
|
||||
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes, bytearray)):
|
||||
return obj
|
||||
elif isinstance(obj, float):
|
||||
return round(obj, 6)
|
||||
|
||||
@@ -11,7 +11,7 @@ import numpy as np
|
||||
import gurobipy as gp
|
||||
|
||||
from miplearn.features.extractor import FeaturesExtractor
|
||||
from miplearn.features.sample import Hdf5Sample
|
||||
from miplearn.features.sample import Hdf5Sample, MemorySample
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
from miplearn.solvers.internal import Variables, Constraints
|
||||
|
||||
Reference in New Issue
Block a user