mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Sample: do not check data by default; minor fixes
This commit is contained in:
@@ -110,10 +110,12 @@ class MemorySample(Sample):
|
||||
def __init__(
|
||||
self,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
check_data: bool = False,
|
||||
) -> None:
|
||||
if data is None:
|
||||
data = {}
|
||||
self._data: Dict[str, Any] = data
|
||||
self._check_data = check_data
|
||||
|
||||
@overrides
|
||||
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||
@@ -142,19 +144,22 @@ class MemorySample(Sample):
|
||||
def put_scalar(self, key: str, value: Scalar) -> None:
|
||||
if value is None:
|
||||
return
|
||||
self._assert_is_scalar(value)
|
||||
if self._check_data:
|
||||
self._assert_is_scalar(value)
|
||||
self._put(key, value)
|
||||
|
||||
@overrides
|
||||
def put_vector(self, key: str, value: Vector) -> None:
|
||||
if value is None:
|
||||
return
|
||||
self._assert_is_vector(value)
|
||||
if self._check_data:
|
||||
self._assert_is_vector(value)
|
||||
self._put(key, value)
|
||||
|
||||
@overrides
|
||||
def put_vector_list(self, key: str, value: VectorList) -> None:
|
||||
self._assert_is_vector_list(value)
|
||||
if self._check_data:
|
||||
self._assert_is_vector_list(value)
|
||||
self._put(key, value)
|
||||
|
||||
def _get(self, key: str) -> Optional[Any]:
|
||||
@@ -175,8 +180,14 @@ class Hdf5Sample(Sample):
|
||||
are actually accessed, and therefore it is more scalable.
|
||||
"""
|
||||
|
||||
def __init__(self, filename: str, mode: str = "r+") -> None:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
mode: str = "r+",
|
||||
check_data: bool = False,
|
||||
) -> None:
|
||||
self.file = h5py.File(filename, mode, libver="latest")
|
||||
self._check_data = check_data
|
||||
|
||||
@overrides
|
||||
def get_bytes(self, key: str) -> Optional[Bytes]:
|
||||
@@ -230,27 +241,30 @@ class Hdf5Sample(Sample):
|
||||
|
||||
@overrides
|
||||
def put_bytes(self, key: str, value: Bytes) -> None:
|
||||
assert isinstance(
|
||||
value, (bytes, bytearray)
|
||||
), f"bytes expected; found: {value}" # type: ignore
|
||||
if self._check_data:
|
||||
assert isinstance(
|
||||
value, (bytes, bytearray)
|
||||
), f"bytes expected; found: {value}" # type: ignore
|
||||
self._put(key, np.frombuffer(value, dtype="uint8"), compress=True)
|
||||
|
||||
@overrides
|
||||
def put_scalar(self, key: str, value: Any) -> None:
|
||||
if value is None:
|
||||
return
|
||||
self._assert_is_scalar(value)
|
||||
if self._check_data:
|
||||
self._assert_is_scalar(value)
|
||||
self._put(key, value)
|
||||
|
||||
@overrides
|
||||
def put_vector(self, key: str, value: Vector) -> None:
|
||||
if value is None:
|
||||
return
|
||||
self._assert_is_vector(value)
|
||||
if self._check_data:
|
||||
self._assert_is_vector(value)
|
||||
|
||||
for v in value:
|
||||
# Convert strings to bytes
|
||||
if isinstance(v, str):
|
||||
if isinstance(v, str) or v is None:
|
||||
value = np.array(
|
||||
[u if u is not None else b"" for u in value],
|
||||
dtype="S",
|
||||
@@ -266,7 +280,8 @@ class Hdf5Sample(Sample):
|
||||
|
||||
@overrides
|
||||
def put_vector_list(self, key: str, value: VectorList) -> None:
|
||||
self._assert_is_vector_list(value)
|
||||
if self._check_data:
|
||||
self._assert_is_vector_list(value)
|
||||
padded, lens = _pad(value)
|
||||
self.put_vector(f"{key}_lengths", lens)
|
||||
data = None
|
||||
@@ -297,7 +312,6 @@ class Hdf5Sample(Sample):
|
||||
|
||||
|
||||
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
|
||||
veclist = deepcopy(veclist)
|
||||
lens = [len(v) if v is not None else -1 for v in veclist]
|
||||
maxlen = max(lens)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user