diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index b0ff220..75c79b4 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -267,7 +267,10 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None: def _equals_preprocess(obj: Any) -> Any: if isinstance(obj, np.ndarray): - return np.round(obj, decimals=6).tolist() + if obj.dtype == "float64": + return np.round(obj, decimals=6).tolist() + else: + return obj.tolist() elif isinstance(obj, (int, str)): return obj elif isinstance(obj, float): diff --git a/tests/test_features.py b/tests/test_features.py index 446f895..f14fd18 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -150,5 +150,9 @@ def test_assert_equals() -> None: VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore VariableFeatures(values=np.array([1.0, 2.0])), # type: ignore ) + assert_equals( + np.array([True, True]), + [True, True], + ) assert_equals((1.0,), (1.0,)) assert_equals({"x": 10}, {"x": 10})