assert_equals: Handle ndarray with booleans

This commit is contained in:
2021-05-20 11:38:35 -05:00
parent 52093eb1c0
commit ddd136c661
2 changed files with 8 additions and 1 deletions

View File

@@ -267,7 +267,10 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None:
def _equals_preprocess(obj: Any) -> Any:
if isinstance(obj, np.ndarray):
return np.round(obj, decimals=6).tolist()
if obj.dtype == "float64":
return np.round(obj, decimals=6).tolist()
else:
return obj.tolist()
elif isinstance(obj, (int, str)):
return obj
elif isinstance(obj, float):

View File

@@ -150,5 +150,9 @@ def test_assert_equals() -> None:
VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore
VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore
)
assert_equals(
np.array([True, True]),
[True, True],
)
assert_equals((1.0,), (1.0,))
assert_equals({"x": 10}, {"x": 10})