mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Re-add sample.{get,put}_bytes
This commit is contained in:
@@ -70,7 +70,9 @@ class Sample(ABC):
|
|||||||
assert False, f"scalar expected; found instead: {value} ({value.__class__})"
|
assert False, f"scalar expected; found instead: {value} ({value.__class__})"
|
||||||
|
|
||||||
def _assert_is_array(self, value: np.ndarray) -> None:
|
def _assert_is_array(self, value: np.ndarray) -> None:
|
||||||
assert isinstance(value, np.ndarray)
|
assert isinstance(
|
||||||
|
value, np.ndarray
|
||||||
|
), f"np.ndarray expected; found instead: {value.__class__}"
|
||||||
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
|
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
|
||||||
|
|
||||||
def _assert_is_sparse(self, value: Any) -> None:
|
def _assert_is_sparse(self, value: Any) -> None:
|
||||||
@@ -205,3 +207,18 @@ class Hdf5Sample(Sample):
|
|||||||
assert col is not None
|
assert col is not None
|
||||||
assert data is not None
|
assert data is not None
|
||||||
return coo_matrix((data, (row, col)))
|
return coo_matrix((data, (row, col)))
|
||||||
|
|
||||||
|
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||||
|
if key not in self.file:
|
||||||
|
return None
|
||||||
|
ds = self.file[key]
|
||||||
|
assert (
|
||||||
|
len(ds.shape) == 1
|
||||||
|
), f"1-dimensional array expected; found shape {ds.shape}"
|
||||||
|
return ds[()].tobytes()
|
||||||
|
|
||||||
|
def put_bytes(self, key: str, value: Bytes) -> None:
|
||||||
|
assert isinstance(
|
||||||
|
value, (bytes, bytearray)
|
||||||
|
), f"bytes expected; found: {value.__class__}" # type: ignore
|
||||||
|
self.put_array(key, np.frombuffer(value, dtype="uint8"))
|
||||||
|
|||||||
@@ -111,14 +111,16 @@ class FileInstance(Instance):
|
|||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
if self.instance is not None:
|
if self.instance is not None:
|
||||||
return
|
return
|
||||||
self.instance = pickle.loads(self.h5.get_array("pickled").tobytes())
|
pkl = self.h5.get_bytes("pickled")
|
||||||
|
assert pkl is not None
|
||||||
|
self.instance = pickle.loads(pkl)
|
||||||
assert isinstance(self.instance, Instance)
|
assert isinstance(self.instance, Instance)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save(cls, instance: Instance, filename: str) -> None:
|
def save(cls, instance: Instance, filename: str) -> None:
|
||||||
h5 = Hdf5Sample(filename, mode="w")
|
h5 = Hdf5Sample(filename, mode="w")
|
||||||
instance_pkl = np.frombuffer(pickle.dumps(instance), dtype=np.int8)
|
instance_pkl = pickle.dumps(instance)
|
||||||
h5.put_array("pickled", instance_pkl)
|
h5.put_bytes("pickled", instance_pkl)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def create_sample(self) -> Sample:
|
def create_sample(self) -> Sample:
|
||||||
|
|||||||
@@ -275,7 +275,7 @@ def _equals_preprocess(obj: Any) -> Any:
|
|||||||
return np.round(obj, decimals=6).tolist()
|
return np.round(obj, decimals=6).tolist()
|
||||||
else:
|
else:
|
||||||
return obj.tolist()
|
return obj.tolist()
|
||||||
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes)):
|
elif isinstance(obj, (int, str, bool, np.bool_, np.bytes_, bytes, bytearray)):
|
||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, float):
|
elif isinstance(obj, float):
|
||||||
return round(obj, 6)
|
return round(obj, 6)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import gurobipy as gp
|
import gurobipy as gp
|
||||||
|
|
||||||
from miplearn.features.extractor import FeaturesExtractor
|
from miplearn.features.extractor import FeaturesExtractor
|
||||||
from miplearn.features.sample import Hdf5Sample
|
from miplearn.features.sample import Hdf5Sample, MemorySample
|
||||||
from miplearn.instance.base import Instance
|
from miplearn.instance.base import Instance
|
||||||
from miplearn.solvers.gurobi import GurobiSolver
|
from miplearn.solvers.gurobi import GurobiSolver
|
||||||
from miplearn.solvers.internal import Variables, Constraints
|
from miplearn.solvers.internal import Variables, Constraints
|
||||||
|
|||||||
Reference in New Issue
Block a user