diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index 28163d0..b0ff220 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -15,24 +15,6 @@ inf = float("inf") # This file is in the main source folder, so that it can be called from Julia. -def _round(obj: Any) -> Any: - if obj is None: - return None - if isinstance(obj, float): - return round(obj, 6) - if isinstance(obj, tuple): - return tuple([_round(v) for v in obj]) - if isinstance(obj, list): - return [_round(v) for v in obj] - if isinstance(obj, dict): - return {key: _round(value) for (key, value) in obj.items()} - if isinstance(obj, VariableFeatures): - obj.__dict__ = _round(obj.__dict__) - if isinstance(obj, ConstraintFeatures): - obj.__dict__ = _round(obj.__dict__) - return obj - - def _filter_attrs(allowed_keys: List[str], obj: Any) -> Any: for key in obj.__dict__.keys(): if key not in allowed_keys: @@ -98,7 +80,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: # Fetch variables (after-lp) assert_equals( - _round(solver.get_variables(with_static=False)), + solver.get_variables(with_static=False), _filter_attrs( solver.get_variable_attrs(), VariableFeatures( @@ -118,7 +100,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: # Fetch constraints (after-lp) assert_equals( - _round(solver.get_constraints(with_static=False)), + solver.get_constraints(with_static=False), _filter_attrs( solver.get_constraint_attrs(), ConstraintFeatures( @@ -151,7 +133,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: # Fetch variables (after-mip) assert_equals( - _round(solver.get_variables(with_static=False)), + solver.get_variables(with_static=False), _filter_attrs( solver.get_variable_attrs(), VariableFeatures( @@ -163,7 +145,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: # Fetch constraints (after-mip) assert_equals( - _round(solver.get_constraints(with_static=False)), + solver.get_constraints(with_static=False), _filter_attrs( solver.get_constraint_attrs(), ConstraintFeatures( @@ -185,7 +167,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: # Add constraint and verify it affects solution solver.add_constraints(cf) assert_equals( - _round(solver.get_constraints(with_static=True)), + solver.get_constraints(with_static=True), _filter_attrs( solver.get_constraint_attrs(), ConstraintFeatures( @@ -283,26 +265,28 @@ def run_lazy_cb_tests(solver: InternalSolver) -> None: assert_equals(solution["x[0]"], 0.0) -def _recursive_convert_ndarray_to_list(obj: Any) -> Any: +def _equals_preprocess(obj: Any) -> Any: if isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, (int, float, str)): + return np.round(obj, decimals=6).tolist() + elif isinstance(obj, (int, str)): return obj + elif isinstance(obj, float): + return round(obj, 6) elif isinstance(obj, list): - return [_recursive_convert_ndarray_to_list(i) for i in obj] + return [_equals_preprocess(i) for i in obj] elif isinstance(obj, tuple): - return tuple(_recursive_convert_ndarray_to_list(i) for i in obj) + return tuple(_equals_preprocess(i) for i in obj) elif obj is None: return None elif isinstance(obj, dict): - return {k: _recursive_convert_ndarray_to_list(v) for (k, v) in obj.items()} + return {k: _equals_preprocess(v) for (k, v) in obj.items()} else: for key in obj.__dict__.keys(): - obj.__dict__[key] = _recursive_convert_ndarray_to_list(obj.__dict__[key]) + obj.__dict__[key] = _equals_preprocess(obj.__dict__[key]) return obj def assert_equals(left: Any, right: Any) -> None: - left = _recursive_convert_ndarray_to_list(left) - right = _recursive_convert_ndarray_to_list(right) + left = _equals_preprocess(left) + right = _equals_preprocess(right) assert left == right, f"left:\n{left}\nright:\n{right}" diff --git a/tests/solvers/test_learning_solver.py b/tests/solvers/test_learning_solver.py index f453b89..85b33a8 100644 --- a/tests/solvers/test_learning_solver.py +++ b/tests/solvers/test_learning_solver.py @@ -16,8 +16,8 @@ from miplearn.solvers.internal import InternalSolver from miplearn.solvers.learning import LearningSolver # noinspection PyUnresolvedReferences -from miplearn.solvers.tests import _round from tests.solvers.test_internal_solver import internal_solvers +from miplearn.solvers.tests import assert_equals logger = logging.getLogger(__name__) @@ -51,7 +51,7 @@ def test_learning_solver( after_lp = sample.after_lp assert after_lp is not None assert after_lp.variables is not None - assert _round(after_lp.variables.values) == [1.0, 0.923077, 1.0, 0.0, 67.0] + assert_equals(after_lp.variables.values, [1.0, 0.923077, 1.0, 0.0, 67.0]) assert after_lp.lp_solve is not None assert after_lp.lp_solve.lp_value is not None assert round(after_lp.lp_solve.lp_value, 3) == 1287.923 diff --git a/tests/test_features.py b/tests/test_features.py index be83b79..446f895 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -2,6 +2,8 @@ # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. +import numpy as np + from miplearn.features import ( FeaturesExtractor, InstanceFeatures, @@ -9,11 +11,7 @@ from miplearn.features import ( ConstraintFeatures, ) from miplearn.solvers.gurobi import GurobiSolver -from miplearn.solvers.tests import ( - assert_equals, - _round, -) -import numpy as np +from miplearn.solvers.tests import assert_equals inf = float("inf") @@ -30,7 +28,7 @@ def test_knapsack() -> None: assert features.instance is not None assert_equals( - _round(features.variables), + features.variables, VariableFeatures( names=["x[0]", "x[1]", "x[2]", "x[3]", "z"], basis_status=["U", "B", "U", "L", "U"], @@ -64,7 +62,7 @@ def test_knapsack() -> None: ), ) assert_equals( - _round(features.constraints), + features.constraints, ConstraintFeatures( basis_status=["N"], categories=["eq_capacity"],