Implement Hdf5Sample

master
Alinson S. Xavier 4 years ago
parent 021a71f60c
commit 0a399deeee
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -5,6 +5,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Optional, Any from typing import Dict, Optional, Any
import h5py
import numpy as np
from overrides import overrides
class Sample(ABC): class Sample(ABC):
"""Abstract dictionary-like class that stores training data.""" """Abstract dictionary-like class that stores training data."""
@ -15,8 +19,29 @@ class Sample(ABC):
@abstractmethod
def put(self, key: str, value: Any) -> None:
    """
    Add a new key/value pair to the sample. If the key already exists,
    the previous value is silently replaced.

    Only the following data types are supported:
    - str, bool, int, float
    - List[str], List[bool], List[int], List[float]
    """
def _assert_supported(self, value: Any) -> None:
def _is_primitive(v: Any) -> bool:
if isinstance(v, (str, bool, int, float)):
return True
return False
if _is_primitive(value):
return
if isinstance(value, list):
if _is_primitive(value[0]):
return
assert False, f"Value has unsupported type: {value}"
class MemorySample(Sample): class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory.""" """Dictionary-like class that stores training data in-memory."""
@ -29,11 +54,47 @@ class MemorySample(Sample):
data = {} data = {}
self._data: Dict[str, Any] = data self._data: Dict[str, Any] = data
@overrides
def get(self, key: str) -> Optional[Any]:
    """Return the value stored under `key`, or None if the key is absent."""
    return self._data.get(key)
@overrides
def put(self, key: str, value: Any) -> None:
    """Store `value` under `key`, replacing any value previously stored there."""
    # Type validation is currently disabled for the in-memory implementation:
    # self._assert_supported(value)
    self._data[key] = value
class Hdf5Sample(Sample):
    """
    Dictionary-like class that stores training data in an HDF5 file.

    Unlike MemorySample, this class only loads to memory the parts of the data
    set that are actually accessed, and therefore it is more scalable.
    """

    def __init__(self, filename: str, mode: str = "r+") -> None:
        """
        Open the HDF5 file that backs this sample.

        `mode` is forwarded to h5py.File. The default, "r+", opens an
        existing file for reading and writing; pass "w" or "a" to let the
        sample create the file.
        """
        self.file = h5py.File(filename, mode)

    @overrides
    def get(self, key: str) -> Optional[Any]:
        """
        Return the value stored under `key` as a native Python scalar or
        list, or None if the key is not present in the file.
        """
        # A missing key yields None (as in MemorySample.get), not KeyError.
        if key not in self.file:
            return None
        ds = self.file[key]
        is_scalar = ds.shape == ()
        if h5py.check_string_dtype(ds.dtype):
            # Read through asstr() so that str values come back, not bytes.
            strings = ds.asstr()
            return strings[()] if is_scalar else strings[:].tolist()
        # tolist() converts NumPy scalars and arrays to native Python values.
        return ds[()].tolist() if is_scalar else ds[:].tolist()

    @overrides
    def put(self, key: str, value: Any) -> None:
        """
        Store `value` under `key`, replacing any previously stored dataset.

        Raises AssertionError if `value` has an unsupported type.
        """
        self._assert_supported(value)
        # HDF5 datasets cannot be overwritten in place; drop the old one.
        if key in self.file:
            del self.file[key]
        if isinstance(value, list) and value and isinstance(value[0], str):
            # NumPy converts a list of str into a fixed-width unicode array,
            # which has no HDF5 equivalent, so a variable-length UTF-8 string
            # type is requested explicitly.
            # NOTE(review): confirm against the h5py version in use.
            self.file.create_dataset(
                key,
                data=value,
                dtype=h5py.string_dtype(),
            )
        else:
            self.file.create_dataset(key, data=value)

@ -1,31 +1,44 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from tempfile import NamedTemporaryFile
from typing import Any
from miplearn.features.sample import MemorySample, Sample from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
def _test_sample(sample: Sample) -> None:
    """Check that each supported kind of value survives a put/get round trip."""
    scalars = ["A", True, 1, 1.0]
    vectors = [
        ["A", "BB", "CCC", "こんにちは"],
        [True, True, False],
        [1, 2, 3],
        [1.0, 2.0, 3.0],
    ]
    for value in scalars + vectors:
        _assert_roundtrip(sample, value)
def _assert_roundtrip(sample: Sample, expected: Any) -> None:
    """
    Store `expected` in `sample` under a fixed key, read it back, and check
    that the value and its Python class (or, for lists, the class of the
    first element) are preserved.
    """
    sample.put("key", expected)
    stored = sample.get("key")
    assert stored == expected
    assert stored is not None

    # Scalars: compare the class of the value itself.
    if not isinstance(stored, list):
        assert isinstance(stored, expected.__class__), (
            f"Expected class {expected.__class__}, "
            f"found class {stored.__class__} instead"
        )
        return

    # Lists: compare the class of the first element only.
    assert isinstance(stored[0], expected[0].__class__), (
        f"Expected class {expected[0].__class__}, "
        f"found {stored[0].__class__} instead"
    )
def test_memory_sample() -> None:
    """Run the shared round-trip checks against the in-memory sample."""
    sample = MemorySample()
    _test_sample(sample)
def test_hdf5_sample() -> None:
    """Run the shared round-trip checks against the HDF5-backed sample."""
    # Imported locally so this test carries its own dependency on h5py.
    import h5py

    # The context manager closes and removes the temporary file
    # deterministically, even if an assertion fails.
    with NamedTemporaryFile() as file:
        # Hdf5Sample opens its file in "r+" mode, which needs an existing HDF5
        # container. A freshly created temporary file is empty, so it is
        # initialized as an empty HDF5 file first.
        h5py.File(file.name, "w").close()
        _test_sample(Hdf5Sample(file.name))

Loading…
Cancel
Save