Sample: Allow numpy arrays

This commit is contained in:
2021-07-28 08:21:56 -05:00
parent 6fd839351c
commit fc55a077f2
2 changed files with 16 additions and 5 deletions

View File

@@ -19,6 +19,7 @@ Vector = Union[
List[int],
List[float],
List[Optional[str]],
np.ndarray,
]
VectorList = Union[
List[List[bool]],
@@ -86,12 +87,16 @@ class Sample(ABC):
assert False, f"Scalar expected; found instead: {value}"
def _assert_is_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
assert isinstance(
value, (list, np.ndarray)
), f"List or numpy array expected; found instead: {value}"
for v in value:
self._assert_is_scalar(v)
def _assert_is_vector_list(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
assert isinstance(
value, (list, np.ndarray)
), f"List or numpy array expected; found instead: {value}"
for v in value:
if v is None:
continue