Implement Hdf5Sample

This commit is contained in:
2021-07-14 09:56:25 -05:00
parent 021a71f60c
commit 0a399deeee
2 changed files with 92 additions and 18 deletions

View File

@@ -1,31 +1,44 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from tempfile import NamedTemporaryFile
from typing import Any
from miplearn.features.sample import MemorySample, Sample
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
def _test_sample(sample: Sample) -> None:
# Strings
sample.put("str", "hello")
assert sample.get("str") == "hello"
_assert_roundtrip(sample, "A")
_assert_roundtrip(sample, True)
_assert_roundtrip(sample, 1)
_assert_roundtrip(sample, 1.0)
_assert_roundtrip(sample, ["A", "BB", "CCC", "こんにちは"])
_assert_roundtrip(sample, [True, True, False])
_assert_roundtrip(sample, [1, 2, 3])
_assert_roundtrip(sample, [1.0, 2.0, 3.0])
# Numbers
sample.put("int", 1)
sample.put("float", 5.0)
assert sample.get("int") == 1
assert sample.get("float") == 5.0
# List of strings
sample.put("strlist", ["hello", "world"])
assert sample.get("strlist") == ["hello", "world"]
# List of numbers
sample.put("intlist", [1, 2, 3])
sample.put("floatlist", [4.0, 5.0, 6.0])
assert sample.get("intlist") == [1, 2, 3]
assert sample.get("floatlist") == [4.0, 5.0, 6.0]
def _assert_roundtrip(sample: Sample, expected: Any) -> None:
sample.put("key", expected)
actual = sample.get("key")
assert actual == expected
assert actual is not None
if isinstance(actual, list):
assert isinstance(actual[0], expected[0].__class__), (
f"Expected class {expected[0].__class__}, "
f"found {actual[0].__class__} instead"
)
else:
assert isinstance(actual, expected.__class__), (
f"Expected class {expected.__class__}, "
f"found class {actual.__class__} instead"
)
def test_memory_sample() -> None:
_test_sample(MemorySample())
def test_hdf5_sample() -> None:
file = NamedTemporaryFile()
_test_sample(Hdf5Sample(file.name))