mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Improve error messages in assertions
This commit is contained in:
@@ -84,19 +84,19 @@ class Sample(ABC):
|
|||||||
return
|
return
|
||||||
if isinstance(value, (str, bool, int, float)):
|
if isinstance(value, (str, bool, int, float)):
|
||||||
return
|
return
|
||||||
assert False, f"Scalar expected; found instead: {value}"
|
assert False, f"scalar expected; found instead: {value}"
|
||||||
|
|
||||||
def _assert_is_vector(self, value: Any) -> None:
|
def _assert_is_vector(self, value: Any) -> None:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
value, (list, np.ndarray)
|
value, (list, np.ndarray)
|
||||||
), f"List or numpy array expected; found instead: {value}"
|
), f"list or numpy array expected; found instead: {value}"
|
||||||
for v in value:
|
for v in value:
|
||||||
self._assert_is_scalar(v)
|
self._assert_is_scalar(v)
|
||||||
|
|
||||||
def _assert_is_vector_list(self, value: Any) -> None:
|
def _assert_is_vector_list(self, value: Any) -> None:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
value, (list, np.ndarray)
|
value, (list, np.ndarray)
|
||||||
), f"List or numpy array expected; found instead: {value}"
|
), f"list or numpy array expected; found instead: {value}"
|
||||||
for v in value:
|
for v in value:
|
||||||
if v is None:
|
if v is None:
|
||||||
continue
|
continue
|
||||||
@@ -132,7 +132,7 @@ class MemorySample(Sample):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_bytes(self, key: str, value: bytes) -> None:
|
def put_bytes(self, key: str, value: bytes) -> None:
|
||||||
assert isinstance(value, bytes)
|
assert isinstance(value, bytes), f"bytes expected; found: {value}"
|
||||||
self._put(key, value)
|
self._put(key, value)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
@@ -180,7 +180,9 @@ class Hdf5Sample(Sample):
|
|||||||
if key not in self.file:
|
if key not in self.file:
|
||||||
return None
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
assert len(ds.shape) == 1
|
assert (
|
||||||
|
len(ds.shape) == 1
|
||||||
|
), f"1-dimensional array expected; found shape {ds.shape}"
|
||||||
return ds[()].tobytes()
|
return ds[()].tobytes()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
@@ -188,7 +190,9 @@ class Hdf5Sample(Sample):
|
|||||||
if key not in self.file:
|
if key not in self.file:
|
||||||
return None
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
assert len(ds.shape) == 0
|
assert (
|
||||||
|
len(ds.shape) == 0
|
||||||
|
), f"0-dimensional array expected; found shape {ds.shape}"
|
||||||
if h5py.check_string_dtype(ds.dtype):
|
if h5py.check_string_dtype(ds.dtype):
|
||||||
return ds.asstr()[()]
|
return ds.asstr()[()]
|
||||||
else:
|
else:
|
||||||
@@ -199,7 +203,9 @@ class Hdf5Sample(Sample):
|
|||||||
if key not in self.file:
|
if key not in self.file:
|
||||||
return None
|
return None
|
||||||
ds = self.file[key]
|
ds = self.file[key]
|
||||||
assert len(ds.shape) == 1
|
assert (
|
||||||
|
len(ds.shape) == 1
|
||||||
|
), f"1-dimensional array expected; found shape {ds.shape}"
|
||||||
if h5py.check_string_dtype(ds.dtype):
|
if h5py.check_string_dtype(ds.dtype):
|
||||||
result = ds.asstr()[:].tolist()
|
result = ds.asstr()[:].tolist()
|
||||||
result = [r if len(r) > 0 else None for r in result]
|
result = [r if len(r) > 0 else None for r in result]
|
||||||
@@ -221,7 +227,7 @@ class Hdf5Sample(Sample):
|
|||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def put_bytes(self, key: str, value: bytes) -> None:
|
def put_bytes(self, key: str, value: bytes) -> None:
|
||||||
assert isinstance(value, bytes)
|
assert isinstance(value, bytes), f"bytes expected; found: {value}"
|
||||||
self._put(key, np.frombuffer(value, dtype="uint8"))
|
self._put(key, np.frombuffer(value, dtype="uint8"))
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
@@ -282,13 +288,13 @@ def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
|
|||||||
elif isinstance(v[0], str):
|
elif isinstance(v[0], str):
|
||||||
constant = ""
|
constant = ""
|
||||||
else:
|
else:
|
||||||
assert False, f"Unsupported data type: {v[0]}"
|
assert False, f"unsupported data type: {v[0]}"
|
||||||
|
|
||||||
# Pad vectors
|
# Pad vectors
|
||||||
for (i, vi) in enumerate(veclist):
|
for (i, vi) in enumerate(veclist):
|
||||||
if vi is None:
|
if vi is None:
|
||||||
vi = veclist[i] = []
|
vi = veclist[i] = []
|
||||||
assert isinstance(vi, list)
|
assert isinstance(vi, list), f"list expected; found: {vi}"
|
||||||
for k in range(len(vi), maxlen):
|
for k in range(len(vi), maxlen):
|
||||||
vi.append(constant)
|
vi.append(constant)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user