diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index b9540c4..a93016d 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -4,6 +4,8 @@ from typing import Any, List +import numpy as np + from miplearn.features import VariableFeatures, ConstraintFeatures from miplearn.solvers.internal import InternalSolver @@ -282,4 +284,8 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None: def assert_equals(left: Any, right: Any) -> None: + if isinstance(left, np.ndarray): + left = left.tolist() + if isinstance(right, np.ndarray): + right = right.tolist() assert left == right, f"left:\n{left}\nright:\n{right}" diff --git a/tests/test_features.py b/tests/test_features.py index 5d2cf2d..aa23fb4 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -13,6 +13,7 @@ from miplearn.solvers.tests import ( assert_equals, _round, ) +import numpy as np inf = float("inf") @@ -134,3 +135,16 @@ def test_constraint_getindex() -> None: ], ), ) + + +def test_assert_equals() -> None: + assert_equals("hello", "hello") + assert_equals([1.0, 2.0], [1.0, 2.0]) + assert_equals( + np.array([1.0, 2.0]), + np.array([1.0, 2.0]), + ) + assert_equals( + np.array([[1.0, 2.0], [3.0, 4.0]]), + np.array([[1.0, 2.0], [3.0, 4.0]]), + )