Implement {get,put}_vector_list

master
Alinson S. Xavier 4 years ago
parent 8fc7c6ab71
commit 8d89285cb9
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -46,7 +46,7 @@ class FeaturesExtractor:
self._extract_user_features_constrs(instance, sample)
self._extract_user_features_instance(instance, sample)
self._extract_var_features_AlvLouWeh2017(sample)
sample.put(
sample.put_vector_list(
"var_features",
self._combine(
sample,
@ -82,7 +82,7 @@ class FeaturesExtractor:
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks)
self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_")
sample.put(
sample.put_vector_list(
"lp_var_features",
self._combine(
sample,
@ -103,7 +103,7 @@ class FeaturesExtractor:
],
),
)
sample.put(
sample.put_vector_list(
"lp_constr_features",
self._combine(
sample,
@ -118,7 +118,7 @@ class FeaturesExtractor:
)
instance_features_user = sample.get("instance_features_user")
assert instance_features_user is not None
sample.put(
sample.put_vector(
"lp_instance_features",
instance_features_user
+ [
@ -178,7 +178,7 @@ class FeaturesExtractor:
user_features_i = list(user_features_i)
user_features.append(user_features_i)
sample.put("var_categories", categories)
sample.put("var_features_user", user_features)
sample.put_vector_list("var_features_user", user_features)
def _extract_user_features_constrs(
self,
@ -227,7 +227,7 @@ class FeaturesExtractor:
lazy.append(instance.is_constraint_lazy(cname))
else:
lazy.append(False)
sample.put("constr_features_user", user_features)
sample.put_vector_list("constr_features_user", user_features)
sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories)
@ -250,7 +250,7 @@ class FeaturesExtractor:
)
constr_lazy = sample.get("constr_lazy")
assert constr_lazy is not None
sample.put("instance_features_user", user_features)
sample.put_vector("instance_features_user", user_features)
sample.put_scalar("static_lazy_count", sum(constr_lazy))
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
@ -331,7 +331,7 @@ class FeaturesExtractor:
for v in f:
assert isfinite(v), f"non-finite elements detected: {f}"
features.append(f)
sample.put(f"{prefix}var_features_AlvLouWeh2017", features)
sample.put_vector_list(f"{prefix}var_features_AlvLouWeh2017", features)
def _combine(
self,

@ -3,13 +3,25 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any, Union, List
from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast
import h5py
import numpy as np
from overrides import overrides
Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]]
VectorList = Union[
List[List[bool]],
List[List[str]],
List[List[int]],
List[List[float]],
List[Optional[List[bool]]],
List[Optional[List[str]]],
List[Optional[List[int]]],
List[Optional[List[float]]],
]
class Sample(ABC):
@ -31,6 +43,14 @@ class Sample(ABC):
def put_vector(self, key: str, value: Vector) -> None:
pass
@abstractmethod
def get_vector_list(self, key: str) -> Optional[Any]:
pass
@abstractmethod
def put_vector_list(self, key: str, value: VectorList) -> None:
pass
@abstractmethod
def get(self, key: str) -> Optional[Any]:
pass
@ -65,17 +85,24 @@ class Sample(ABC):
return
assert False, f"Value has unsupported type: {value}"
def _assert_scalar(self, value: Any) -> None:
def _assert_is_scalar(self, value: Any) -> None:
if value is None:
return
if isinstance(value, (str, bool, int, float)):
return
assert False, f"Scalar expected; found instead: {value}"
def _assert_vector(self, value: Any) -> None:
def _assert_is_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
self._assert_scalar(v)
self._assert_is_scalar(v)
def _assert_is_vector_list(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
if v is None:
continue
self._assert_is_vector(v)
class MemorySample(Sample):
@ -94,19 +121,28 @@ class MemorySample(Sample):
return self.get(key)
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_scalar(value)
self.put(key, value)
def get_vector(self, key: str) -> Optional[Any]:
return self.get(key)
@overrides
def get_vector(self, key: str) -> Optional[Any]:
def get_vector_list(self, key: str) -> Optional[Any]:
return self.get(key)
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_is_scalar(value)
self.put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_vector(value)
self._assert_is_vector(value)
self.put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value)
self.put(key, value)
@overrides
@ -145,23 +181,55 @@ class Hdf5Sample(Sample):
def get_vector(self, key: str) -> Optional[Any]:
ds = self.file[key]
assert len(ds.shape) == 1
print(ds.dtype)
if h5py.check_string_dtype(ds.dtype):
return ds.asstr()[:].tolist()
else:
return ds[:].tolist()
@overrides
def get_vector_list(self, key: str) -> Optional[Any]:
ds = self.file[key]
lens = ds.attrs["lengths"]
if h5py.check_string_dtype(ds.dtype):
padded = ds.asstr()[:].tolist()
else:
padded = ds[:].tolist()
return _crop(padded, lens)
@overrides
def put_scalar(self, key: str, value: Any) -> None:
self._assert_scalar(value)
self._assert_is_scalar(value)
self.put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_vector(value)
self._assert_is_vector(value)
self.put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value)
if key in self.file:
del self.file[key]
padded, lens = _pad(value)
data = None
for v in value:
if v is None or len(v) == 0:
continue
if isinstance(v[0], str):
data = np.array(padded, dtype="S")
elif isinstance(v[0], bool):
data = np.array(padded, dtype=bool)
else:
data = np.array(padded)
break
assert data is not None
ds = self.file.create_dataset(key, data=data)
ds.attrs["lengths"] = lens
@overrides
def get(self, key: str) -> Optional[Any]:
ds = self.file[key]
@ -175,3 +243,45 @@ class Hdf5Sample(Sample):
if key in self.file:
del self.file[key]
self.file.create_dataset(key, data=value)
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist)
lens = [len(v) if v is not None else -1 for v in veclist]
maxlen = max(lens)
# Find appropriate constant to pad the vectors
constant: Union[int, float, str, None] = None
for v in veclist:
if v is None or len(v) == 0:
continue
if isinstance(v[0], int):
constant = 0
elif isinstance(v[0], float):
constant = 0.0
elif isinstance(v[0], str):
constant = ""
else:
assert False, f"Unsupported data type: {v[0]}"
assert constant is not None, "veclist must not be completely empty"
# Pad vectors
for (i, vi) in enumerate(veclist):
if vi is None:
vi = veclist[i] = []
assert isinstance(vi, list)
for k in range(len(vi), maxlen):
vi.append(constant)
return veclist, lens
def _crop(veclist: VectorList, lens: List[int]) -> VectorList:
result: VectorList = cast(VectorList, [])
for (i, v) in enumerate(veclist):
if lens[i] < 0:
result.append(None) # type: ignore
else:
assert isinstance(v, list)
result.append(v[: lens[i]])
return result

@ -4,7 +4,7 @@
from tempfile import NamedTemporaryFile
from typing import Any
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop
def test_memory_sample() -> None:
@ -29,16 +29,11 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_vector(sample, [1, 2, 3])
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
# List[Optional[List[Primitive]]]
# _assert_roundtrip(
# sample,
# [
# [1],
# None,
# [2, 2],
# [3, 3, 3],
# ],
# )
# VectorList
_assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
_assert_roundtrip_vector_list(sample, [[True], [False, False], None])
_assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]])
_assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]])
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
@ -57,8 +52,76 @@ def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None:
_assert_same_type(actual[0], expected[0])
def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None:
sample.put_vector_list("key", expected)
actual = sample.get_vector_list("key")
assert actual == expected
assert actual is not None
_assert_same_type(actual[0][0], expected[0][0])
def _assert_same_type(actual: Any, expected: Any) -> None:
assert isinstance(actual, expected.__class__), (
f"Expected class {expected.__class__}, "
f"found class {actual.__class__} instead"
assert isinstance(
actual, expected.__class__
), f"Expected {expected.__class__}, found {actual.__class__} instead"
def test_pad_int() -> None:
_assert_roundtrip_pad(
original=[[1], [2, 2, 2], [], [3, 3], [4, 4, 4, 4], None],
expected_padded=[
[1, 0, 0, 0],
[2, 2, 2, 0],
[0, 0, 0, 0],
[3, 3, 0, 0],
[4, 4, 4, 4],
[0, 0, 0, 0],
],
expected_lens=[1, 3, 0, 2, 4, -1],
dtype=int,
)
def test_pad_float() -> None:
_assert_roundtrip_pad(
original=[[1.0], [2.0, 2.0, 2.0], [3.0, 3.0], [4.0, 4.0, 4.0, 4.0], None],
expected_padded=[
[1.0, 0.0, 0.0, 0.0],
[2.0, 2.0, 2.0, 0.0],
[3.0, 3.0, 0.0, 0.0],
[4.0, 4.0, 4.0, 4.0],
[0.0, 0.0, 0.0, 0.0],
],
expected_lens=[1, 3, 2, 4, -1],
dtype=float,
)
def test_pad_str() -> None:
_assert_roundtrip_pad(
original=[["A"], ["B", "B", "B"], ["C", "C"]],
expected_padded=[["A", "", ""], ["B", "B", "B"], ["C", "C", ""]],
expected_lens=[1, 3, 2],
dtype=str,
)
def _assert_roundtrip_pad(
original: Any,
expected_padded: Any,
expected_lens: Any,
dtype: Any,
) -> None:
actual_padded, actual_lens = _pad(original)
assert actual_padded == expected_padded
assert actual_lens == expected_lens
for v in actual_padded:
for vi in v: # type: ignore
assert isinstance(vi, dtype)
cropped = _crop(actual_padded, actual_lens)
assert cropped == original
for v in cropped:
if v is None:
continue
for vi in v: # type: ignore
assert isinstance(vi, dtype)

Loading…
Cancel
Save