mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Implement Hdf5Sample
This commit is contained in:
@@ -5,6 +5,10 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
from overrides import overrides
|
||||||
|
|
||||||
|
|
||||||
class Sample(ABC):
|
class Sample(ABC):
|
||||||
"""Abstract dictionary-like class that stores training data."""
|
"""Abstract dictionary-like class that stores training data."""
|
||||||
@@ -15,8 +19,29 @@ class Sample(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def put(self, key: str, value: Any) -> None:
|
def put(self, key: str, value: Any) -> None:
|
||||||
|
"""
|
||||||
|
Add a new key/value pair to the sample. If the key already exists,
|
||||||
|
the previous value is silently replaced.
|
||||||
|
|
||||||
|
Only the following data types are supported:
|
||||||
|
- str, bool, int, float
|
||||||
|
- List[str], List[bool], List[int], List[float]
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _assert_supported(self, value: Any) -> None:
    """
    Check that `value` has one of the data types accepted by `put`.

    Accepted values are str, bool, int and float scalars, as well as lists
    whose first element is one of those types. Only the first element of a
    list is inspected. An empty list is accepted, since it holds no element
    of an unsupported type.

    Raises AssertionError if the value is not supported.
    """
    primitives = (str, bool, int, float)
    if isinstance(value, primitives):
        return
    if isinstance(value, list):
        # Guard the empty case explicitly: indexing value[0] on an empty
        # list would raise IndexError instead of giving a type verdict.
        if not value or isinstance(value[0], primitives):
            return
    # Raised explicitly (instead of `assert False`) so the check is still
    # enforced when Python runs with assertions disabled (-O).
    raise AssertionError(f"Value has unsupported type: {value}")
|
||||||
|
|
||||||
|
|
||||||
class MemorySample(Sample):
|
class MemorySample(Sample):
|
||||||
"""Dictionary-like class that stores training data in-memory."""
|
"""Dictionary-like class that stores training data in-memory."""
|
||||||
@@ -29,11 +54,47 @@ class MemorySample(Sample):
|
|||||||
data = {}
|
data = {}
|
||||||
self._data: Dict[str, Any] = data
|
self._data: Dict[str, Any] = data
|
||||||
|
|
||||||
|
@overrides
|
||||||
def get(self, key: str) -> Optional[Any]:
|
def get(self, key: str) -> Optional[Any]:
|
||||||
if key in self._data:
|
if key in self._data:
|
||||||
return self._data[key]
|
return self._data[key]
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@overrides
|
||||||
def put(self, key: str, value: Any) -> None:
|
def put(self, key: str, value: Any) -> None:
|
||||||
|
# self._assert_supported(value)
|
||||||
self._data[key] = value
|
self._data[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
class Hdf5Sample(Sample):
    """
    Dictionary-like class that stores training data in an HDF5 file.

    Unlike MemorySample, this class only loads to memory the parts of the data set that
    are actually accessed, and therefore it is more scalable.
    """

    def __init__(self, filename: str) -> None:
        """Open the HDF5 file at `filename` in "r+" (read/write) mode."""
        self.file = h5py.File(filename, "r+")

    @overrides
    def get(self, key: str) -> Optional[Any]:
        """
        Return the value stored under `key`, or None if the key is absent.

        String datasets are decoded to str. Scalar datasets (shape `()`) are
        returned as a single value and all other datasets as a list.
        """
        # Return None for a missing key, as MemorySample.get does and as the
        # Optional return type promises, instead of raising KeyError.
        if key not in self.file:
            return None
        ds = self.file[key]
        is_scalar = ds.shape == ()
        if h5py.check_string_dtype(ds.dtype):
            # asstr() makes reads yield str objects rather than raw bytes.
            strings = ds.asstr()
            return strings[()] if is_scalar else strings[:].tolist()
        # tolist() converts the numpy result into plain Python values.
        return ds[()].tolist() if is_scalar else ds[:].tolist()

    @overrides
    def put(self, key: str, value: Any) -> None:
        """
        Store `value` under `key`, replacing any dataset already stored there.

        The value is validated with `_assert_supported` before anything is
        written.
        """
        self._assert_supported(value)
        # The existing dataset is deleted and recreated rather than assigned
        # to in place.
        if key in self.file:
            del self.file[key]
        self.file.create_dataset(key, data=value)
|
||||||
|
|||||||
@@ -1,31 +1,44 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from miplearn.features.sample import MemorySample, Sample
|
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
|
||||||
|
|
||||||
|
|
||||||
def _test_sample(sample: Sample) -> None:
|
def _test_sample(sample: Sample) -> None:
|
||||||
# Strings
|
_assert_roundtrip(sample, "A")
|
||||||
sample.put("str", "hello")
|
_assert_roundtrip(sample, True)
|
||||||
assert sample.get("str") == "hello"
|
_assert_roundtrip(sample, 1)
|
||||||
|
_assert_roundtrip(sample, 1.0)
|
||||||
|
_assert_roundtrip(sample, ["A", "BB", "CCC", "こんにちは"])
|
||||||
|
_assert_roundtrip(sample, [True, True, False])
|
||||||
|
_assert_roundtrip(sample, [1, 2, 3])
|
||||||
|
_assert_roundtrip(sample, [1.0, 2.0, 3.0])
|
||||||
|
|
||||||
# Numbers
|
|
||||||
sample.put("int", 1)
|
|
||||||
sample.put("float", 5.0)
|
|
||||||
assert sample.get("int") == 1
|
|
||||||
assert sample.get("float") == 5.0
|
|
||||||
|
|
||||||
# List of strings
|
def _assert_roundtrip(sample: Sample, expected: Any) -> None:
|
||||||
sample.put("strlist", ["hello", "world"])
|
sample.put("key", expected)
|
||||||
assert sample.get("strlist") == ["hello", "world"]
|
actual = sample.get("key")
|
||||||
|
assert actual == expected
|
||||||
# List of numbers
|
assert actual is not None
|
||||||
sample.put("intlist", [1, 2, 3])
|
if isinstance(actual, list):
|
||||||
sample.put("floatlist", [4.0, 5.0, 6.0])
|
assert isinstance(actual[0], expected[0].__class__), (
|
||||||
assert sample.get("intlist") == [1, 2, 3]
|
f"Expected class {expected[0].__class__}, "
|
||||||
assert sample.get("floatlist") == [4.0, 5.0, 6.0]
|
f"found {actual[0].__class__} instead"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert isinstance(actual, expected.__class__), (
|
||||||
|
f"Expected class {expected.__class__}, "
|
||||||
|
f"found class {actual.__class__} instead"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_memory_sample() -> None:
|
def test_memory_sample() -> None:
|
||||||
_test_sample(MemorySample())
|
_test_sample(MemorySample())
|
||||||
|
|
||||||
|
|
||||||
|
def test_hdf5_sample() -> None:
    """Run the shared sample checks against an Hdf5Sample backed by a temporary file."""
    # Keep a reference to the temporary file for the whole test so that it
    # is not cleaned up while the sample is still using it.
    tmp = NamedTemporaryFile()
    # NOTE(review): the temporary file is empty when handed to Hdf5Sample;
    # confirm that the mode Hdf5Sample opens it with accepts such a file.
    sample = Hdf5Sample(tmp.name)
    _test_sample(sample)
|
||||||
|
|||||||
Reference in New Issue
Block a user