From 0a399deeee44341e5cedf786a6c0ac6204756405 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Wed, 14 Jul 2021 09:56:25 -0500 Subject: [PATCH] Implement Hdf5Sample --- miplearn/features/sample.py | 61 +++++++++++++++++++++++++++++++++++ tests/features/test_sample.py | 53 ++++++++++++++++++------------ 2 files changed, 94 insertions(+), 20 deletions(-) diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 2fd8e9f..bd23d9b 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -5,6 +5,10 @@ from abc import ABC, abstractmethod from typing import Dict, Optional, Any +import h5py +import numpy as np +from overrides import overrides + class Sample(ABC): """Abstract dictionary-like class that stores training data.""" @@ -15,8 +19,29 @@ class Sample(ABC): @abstractmethod def put(self, key: str, value: Any) -> None: + """ + Add a new key/value pair to the sample. If the key already exists, + the previous value is silently replaced. + + Only the following data types are supported: + - str, bool, int, float + - List[str], List[bool], List[int], List[float] + """ pass + def _assert_supported(self, value: Any) -> None: + def _is_primitive(v: Any) -> bool: + if isinstance(v, (str, bool, int, float)): + return True + return False + + if _is_primitive(value): + return + if isinstance(value, list): + if _is_primitive(value[0]): + return + assert False, f"Value has unsupported type: {value}" + class MemorySample(Sample): """Dictionary-like class that stores training data in-memory.""" @@ -29,11 +54,47 @@ class MemorySample(Sample): data = {} self._data: Dict[str, Any] = data + @overrides def get(self, key: str) -> Optional[Any]: if key in self._data: return self._data[key] else: return None + @overrides def put(self, key: str, value: Any) -> None: + # self._assert_supported(value) self._data[key] = value + + +class Hdf5Sample(Sample): + """ + Dictionary-like class that stores training data in an HDF5 file. + + Unlike MemorySample, this class only loads to memory the parts of the data set that + are actually accessed, and therefore it is more scalable. + """ + + def __init__(self, filename: str) -> None: + self.file = h5py.File(filename, "r+") + + @overrides + def get(self, key: str) -> Optional[Any]: + ds = self.file[key] + if h5py.check_string_dtype(ds.dtype): + if ds.shape == (): + return ds.asstr()[()] + else: + return ds.asstr()[:].tolist() + else: + if ds.shape == (): + return ds[()].tolist() + else: + return ds[:].tolist() + + @overrides + def put(self, key: str, value: Any) -> None: + self._assert_supported(value) + if key in self.file: + del self.file[key] + self.file.create_dataset(key, data=value) diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index 5bd869b..3cdb5e7 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -1,31 +1,44 @@ # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. +from tempfile import NamedTemporaryFile +from typing import Any -from miplearn.features.sample import MemorySample, Sample +from miplearn.features.sample import MemorySample, Sample, Hdf5Sample def _test_sample(sample: Sample) -> None: - # Strings - sample.put("str", "hello") - assert sample.get("str") == "hello" - - # Numbers - sample.put("int", 1) - sample.put("float", 5.0) - assert sample.get("int") == 1 - assert sample.get("float") == 5.0 - - # List of strings - sample.put("strlist", ["hello", "world"]) - assert sample.get("strlist") == ["hello", "world"] - - # List of numbers - sample.put("intlist", [1, 2, 3]) - sample.put("floatlist", [4.0, 5.0, 6.0]) - assert sample.get("intlist") == [1, 2, 3] - assert sample.get("floatlist") == [4.0, 5.0, 6.0] + _assert_roundtrip(sample, "A") + _assert_roundtrip(sample, True) + _assert_roundtrip(sample, 1) + _assert_roundtrip(sample, 1.0) + _assert_roundtrip(sample, ["A", "BB", "CCC", "こんにちは"]) + _assert_roundtrip(sample, [True, True, False]) + _assert_roundtrip(sample, [1, 2, 3]) + _assert_roundtrip(sample, [1.0, 2.0, 3.0]) + + +def _assert_roundtrip(sample: Sample, expected: Any) -> None: + sample.put("key", expected) + actual = sample.get("key") + assert actual == expected + assert actual is not None + if isinstance(actual, list): + assert isinstance(actual[0], expected[0].__class__), ( + f"Expected class {expected[0].__class__}, " + f"found {actual[0].__class__} instead" + ) + else: + assert isinstance(actual, expected.__class__), ( + f"Expected class {expected.__class__}, " + f"found class {actual.__class__} instead" + ) def test_memory_sample() -> None: _test_sample(MemorySample()) + + +def test_hdf5_sample() -> None: + file = NamedTemporaryFile() + _test_sample(Hdf5Sample(file.name))