# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.

from tempfile import NamedTemporaryFile
from typing import Any

import numpy as np
from scipy.sparse import coo_matrix

from miplearn.features.sample import MemorySample, Sample, Hdf5Sample


def test_memory_sample() -> None:
    """Run the shared round-trip checks against a ``MemorySample``."""
    _test_sample(MemorySample())


def test_hdf5_sample() -> None:
    """Run the shared round-trip checks against an ``Hdf5Sample``.

    The temporary file is managed by a ``with`` block so it is closed (and,
    with ``NamedTemporaryFile``'s default ``delete=True``, removed)
    deterministically when the test finishes, rather than whenever the file
    object happens to be garbage-collected.
    """
    with NamedTemporaryFile() as file:
        _test_sample(Hdf5Sample(file.name))


def _test_sample(sample: Sample) -> None:
    """Exercise scalar, array and sparse put/get round-trips on ``sample``.

    Every round-trip writes to the same key (``"key"``), so the sample is
    repeatedly overwritten with values of different kinds and types. Lookups
    of a key that was never written are expected to return ``None``.
    """
    # Scalars: str, bool, int, float.
    _assert_roundtrip_scalar(sample, "A")
    _assert_roundtrip_scalar(sample, True)
    _assert_roundtrip_scalar(sample, 1)
    _assert_roundtrip_scalar(sample, 1.0)
    assert sample.get_scalar("unknown-key") is None

    # Arrays: bool, int, float and fixed-width byte strings (dtype "S").
    _assert_roundtrip_array(sample, np.array([True, False]))
    _assert_roundtrip_array(sample, np.array([1, 2, 3]))
    _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0]))
    _assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S"))
    assert sample.get_array("unknown-key") is None

    # Sparse matrices in COO format.
    _assert_roundtrip_sparse(
        sample,
        coo_matrix(
            [
                [1.0, 0.0, 0.0],
                [0.0, 2.0, 3.0],
                [0.0, 0.0, 4.0],
            ],
        ),
    )
    assert sample.get_sparse("unknown-key") is None


def _assert_roundtrip_array(sample: Sample, original: np.ndarray) -> None:
    """Store ``original`` with ``put_array`` and check ``get_array`` returns it.

    Shape is compared explicitly and values via ``np.array_equal``: the
    previous ``(recovered == original).all()`` check could pass for a
    recovered array of the wrong shape thanks to broadcasting.
    """
    sample.put_array("key", original)
    recovered = sample.get_array("key")
    assert recovered is not None
    assert isinstance(recovered, np.ndarray)
    assert (
        recovered.shape == original.shape
    ), f"Expected shape {original.shape}, found {recovered.shape} instead"
    assert np.array_equal(recovered, original)


def _assert_roundtrip_scalar(sample: Sample, original: Any) -> None:
    """Store ``original`` with ``put_scalar`` and check ``get_scalar`` returns it.

    The recovered value must compare equal to ``original`` and be an
    instance of ``original``'s class.
    """
    sample.put_scalar("key", original)
    recovered = sample.get_scalar("key")
    # Check for None first so a missing value fails with a clear message
    # instead of a confusing equality failure.
    assert recovered is not None
    assert recovered == original
    assert isinstance(
        recovered, original.__class__
    ), f"Expected {original.__class__}, found {recovered.__class__} instead"


def _assert_roundtrip_sparse(sample: Sample, original: coo_matrix) -> None:
    """Store ``original`` with ``put_sparse`` and check ``get_sparse`` returns it.

    Shapes are compared before values because ``!=`` between sparse
    matrices of different shapes raises instead of reporting a mismatch.
    """
    sample.put_sparse("key", original)
    recovered = sample.get_sparse("key")
    assert recovered is not None
    assert isinstance(recovered, coo_matrix)
    assert (
        recovered.shape == original.shape
    ), f"Expected shape {original.shape}, found {recovered.shape} instead"
    # ``!=`` yields a sparse boolean matrix; a zero sum means no entry differs.
    assert (original != recovered).sum() == 0