Make assert_equals work with np.ndarray

master
Alinson S. Xavier 4 years ago
parent 310394b397
commit cdd38cdfb8
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -4,6 +4,8 @@
from typing import Any, List from typing import Any, List
import numpy as np
from miplearn.features import VariableFeatures, ConstraintFeatures from miplearn.features import VariableFeatures, ConstraintFeatures
from miplearn.solvers.internal import InternalSolver from miplearn.solvers.internal import InternalSolver
@ -282,4 +284,8 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None:
def assert_equals(left: Any, right: Any) -> None: def assert_equals(left: Any, right: Any) -> None:
if isinstance(left, np.ndarray):
left = left.tolist()
if isinstance(right, np.ndarray):
right = right.tolist()
assert left == right, f"left:\n{left}\nright:\n{right}" assert left == right, f"left:\n{left}\nright:\n{right}"

@ -13,6 +13,7 @@ from miplearn.solvers.tests import (
assert_equals, assert_equals,
_round, _round,
) )
import numpy as np
inf = float("inf") inf = float("inf")
@ -134,3 +135,16 @@ def test_constraint_getindex() -> None:
], ],
), ),
) )
def test_assert_equals() -> None:
assert_equals("hello", "hello")
assert_equals([1.0, 2.0], [1.0, 2.0])
assert_equals(
np.array([1.0, 2.0]),
np.array([1.0, 2.0]),
)
assert_equals(
np.array([[1.0, 2.0], [3.0, 4.0]]),
np.array([[1.0, 2.0], [3.0, 4.0]]),
)

Loading…
Cancel
Save