Make assert_equals work with np.ndarray

This commit is contained in:
2021-05-20 10:41:38 -05:00
parent 310394b397
commit cdd38cdfb8
2 changed files with 20 additions and 0 deletions

View File

@@ -4,6 +4,8 @@
from typing import Any, List from typing import Any, List
import numpy as np
from miplearn.features import VariableFeatures, ConstraintFeatures from miplearn.features import VariableFeatures, ConstraintFeatures
from miplearn.solvers.internal import InternalSolver from miplearn.solvers.internal import InternalSolver
@@ -282,4 +284,8 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None:
def assert_equals(left: Any, right: Any) -> None: def assert_equals(left: Any, right: Any) -> None:
if isinstance(left, np.ndarray):
left = left.tolist()
if isinstance(right, np.ndarray):
right = right.tolist()
assert left == right, f"left:\n{left}\nright:\n{right}" assert left == right, f"left:\n{left}\nright:\n{right}"

View File

@@ -13,6 +13,7 @@ from miplearn.solvers.tests import (
assert_equals, assert_equals,
_round, _round,
) )
import numpy as np
inf = float("inf") inf = float("inf")
@@ -134,3 +135,16 @@ def test_constraint_getindex() -> None:
], ],
), ),
) )
def test_assert_equals() -> None:
assert_equals("hello", "hello")
assert_equals([1.0, 2.0], [1.0, 2.0])
assert_equals(
np.array([1.0, 2.0]),
np.array([1.0, 2.0]),
)
assert_equals(
np.array([[1.0, 2.0], [3.0, 4.0]]),
np.array([[1.0, 2.0], [3.0, 4.0]]),
)