mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
assert_equals: Handle ndarray with booleans
This commit is contained in:
@@ -267,7 +267,10 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None:
|
|||||||
|
|
||||||
def _equals_preprocess(obj: Any) -> Any:
|
def _equals_preprocess(obj: Any) -> Any:
|
||||||
if isinstance(obj, np.ndarray):
|
if isinstance(obj, np.ndarray):
|
||||||
|
if obj.dtype == "float64":
|
||||||
return np.round(obj, decimals=6).tolist()
|
return np.round(obj, decimals=6).tolist()
|
||||||
|
else:
|
||||||
|
return obj.tolist()
|
||||||
elif isinstance(obj, (int, str)):
|
elif isinstance(obj, (int, str)):
|
||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, float):
|
elif isinstance(obj, float):
|
||||||
|
|||||||
@@ -150,5 +150,9 @@ def test_assert_equals() -> None:
|
|||||||
VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore
|
VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore
|
||||||
VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore
|
VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore
|
||||||
)
|
)
|
||||||
|
assert_equals(
|
||||||
|
np.array([True, True]),
|
||||||
|
[True, True],
|
||||||
|
)
|
||||||
assert_equals((1.0,), (1.0,))
|
assert_equals((1.0,), (1.0,))
|
||||||
assert_equals({"x": 10}, {"x": 10})
|
assert_equals({"x": 10}, {"x": 10})
|
||||||
|
|||||||
Reference in New Issue
Block a user