Implement {get,put}_vector_list

master
Alinson S. Xavier 4 years ago
parent 8fc7c6ab71
commit 8d89285cb9
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -46,7 +46,7 @@ class FeaturesExtractor:
self._extract_user_features_constrs(instance, sample) self._extract_user_features_constrs(instance, sample)
self._extract_user_features_instance(instance, sample) self._extract_user_features_instance(instance, sample)
self._extract_var_features_AlvLouWeh2017(sample) self._extract_var_features_AlvLouWeh2017(sample)
sample.put( sample.put_vector_list(
"var_features", "var_features",
self._combine( self._combine(
sample, sample,
@ -82,7 +82,7 @@ class FeaturesExtractor:
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up) sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks) sample.put_vector("lp_constr_slacks", constraints.slacks)
self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_") self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_")
sample.put( sample.put_vector_list(
"lp_var_features", "lp_var_features",
self._combine( self._combine(
sample, sample,
@ -103,7 +103,7 @@ class FeaturesExtractor:
], ],
), ),
) )
sample.put( sample.put_vector_list(
"lp_constr_features", "lp_constr_features",
self._combine( self._combine(
sample, sample,
@ -118,7 +118,7 @@ class FeaturesExtractor:
) )
instance_features_user = sample.get("instance_features_user") instance_features_user = sample.get("instance_features_user")
assert instance_features_user is not None assert instance_features_user is not None
sample.put( sample.put_vector(
"lp_instance_features", "lp_instance_features",
instance_features_user instance_features_user
+ [ + [
@ -178,7 +178,7 @@ class FeaturesExtractor:
user_features_i = list(user_features_i) user_features_i = list(user_features_i)
user_features.append(user_features_i) user_features.append(user_features_i)
sample.put("var_categories", categories) sample.put("var_categories", categories)
sample.put("var_features_user", user_features) sample.put_vector_list("var_features_user", user_features)
def _extract_user_features_constrs( def _extract_user_features_constrs(
self, self,
@ -227,7 +227,7 @@ class FeaturesExtractor:
lazy.append(instance.is_constraint_lazy(cname)) lazy.append(instance.is_constraint_lazy(cname))
else: else:
lazy.append(False) lazy.append(False)
sample.put("constr_features_user", user_features) sample.put_vector_list("constr_features_user", user_features)
sample.put_vector("constr_lazy", lazy) sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories) sample.put("constr_categories", categories)
@ -250,7 +250,7 @@ class FeaturesExtractor:
) )
constr_lazy = sample.get("constr_lazy") constr_lazy = sample.get("constr_lazy")
assert constr_lazy is not None assert constr_lazy is not None
sample.put("instance_features_user", user_features) sample.put_vector("instance_features_user", user_features)
sample.put_scalar("static_lazy_count", sum(constr_lazy)) sample.put_scalar("static_lazy_count", sum(constr_lazy))
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based # Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
@ -331,7 +331,7 @@ class FeaturesExtractor:
for v in f: for v in f:
assert isfinite(v), f"non-finite elements detected: {f}" assert isfinite(v), f"non-finite elements detected: {f}"
features.append(f) features.append(f)
sample.put(f"{prefix}var_features_AlvLouWeh2017", features) sample.put_vector_list(f"{prefix}var_features_AlvLouWeh2017", features)
def _combine( def _combine(
self, self,

@ -3,13 +3,25 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Optional, Any, Union, List from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast
import h5py import h5py
import numpy as np
from overrides import overrides from overrides import overrides
Scalar = Union[None, bool, str, int, float] Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]] Vector = Union[None, List[bool], List[str], List[int], List[float]]
VectorList = Union[
List[List[bool]],
List[List[str]],
List[List[int]],
List[List[float]],
List[Optional[List[bool]]],
List[Optional[List[str]]],
List[Optional[List[int]]],
List[Optional[List[float]]],
]
class Sample(ABC): class Sample(ABC):
@ -31,6 +43,14 @@ class Sample(ABC):
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
pass pass
@abstractmethod
def get_vector_list(self, key: str) -> Optional[Any]:
pass
@abstractmethod
def put_vector_list(self, key: str, value: VectorList) -> None:
pass
@abstractmethod @abstractmethod
def get(self, key: str) -> Optional[Any]: def get(self, key: str) -> Optional[Any]:
pass pass
@ -65,17 +85,24 @@ class Sample(ABC):
return return
assert False, f"Value has unsupported type: {value}" assert False, f"Value has unsupported type: {value}"
def _assert_scalar(self, value: Any) -> None: def _assert_is_scalar(self, value: Any) -> None:
if value is None: if value is None:
return return
if isinstance(value, (str, bool, int, float)): if isinstance(value, (str, bool, int, float)):
return return
assert False, f"Scalar expected; found instead: {value}" assert False, f"Scalar expected; found instead: {value}"
def _assert_vector(self, value: Any) -> None: def _assert_is_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}" assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value: for v in value:
self._assert_scalar(v) self._assert_is_scalar(v)
def _assert_is_vector_list(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
if v is None:
continue
self._assert_is_vector(v)
class MemorySample(Sample): class MemorySample(Sample):
@ -94,19 +121,28 @@ class MemorySample(Sample):
return self.get(key) return self.get(key)
@overrides @overrides
def put_scalar(self, key: str, value: Scalar) -> None: def get_vector(self, key: str) -> Optional[Any]:
self._assert_scalar(value) return self.get(key)
self.put(key, value)
@overrides @overrides
def get_vector(self, key: str) -> Optional[Any]: def get_vector_list(self, key: str) -> Optional[Any]:
return self.get(key) return self.get(key)
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_is_scalar(value)
self.put(key, value)
@overrides @overrides
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
if value is None: if value is None:
return return
self._assert_vector(value) self._assert_is_vector(value)
self.put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value)
self.put(key, value) self.put(key, value)
@overrides @overrides
@ -145,23 +181,55 @@ class Hdf5Sample(Sample):
def get_vector(self, key: str) -> Optional[Any]: def get_vector(self, key: str) -> Optional[Any]:
ds = self.file[key] ds = self.file[key]
assert len(ds.shape) == 1 assert len(ds.shape) == 1
print(ds.dtype)
if h5py.check_string_dtype(ds.dtype): if h5py.check_string_dtype(ds.dtype):
return ds.asstr()[:].tolist() return ds.asstr()[:].tolist()
else: else:
return ds[:].tolist() return ds[:].tolist()
@overrides
def get_vector_list(self, key: str) -> Optional[Any]:
ds = self.file[key]
lens = ds.attrs["lengths"]
if h5py.check_string_dtype(ds.dtype):
padded = ds.asstr()[:].tolist()
else:
padded = ds[:].tolist()
return _crop(padded, lens)
@overrides @overrides
def put_scalar(self, key: str, value: Any) -> None: def put_scalar(self, key: str, value: Any) -> None:
self._assert_scalar(value) self._assert_is_scalar(value)
self.put(key, value) self.put(key, value)
@overrides @overrides
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
if value is None: if value is None:
return return
self._assert_vector(value) self._assert_is_vector(value)
self.put(key, value) self.put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value)
if key in self.file:
del self.file[key]
padded, lens = _pad(value)
data = None
for v in value:
if v is None or len(v) == 0:
continue
if isinstance(v[0], str):
data = np.array(padded, dtype="S")
elif isinstance(v[0], bool):
data = np.array(padded, dtype=bool)
else:
data = np.array(padded)
break
assert data is not None
ds = self.file.create_dataset(key, data=data)
ds.attrs["lengths"] = lens
@overrides @overrides
def get(self, key: str) -> Optional[Any]: def get(self, key: str) -> Optional[Any]:
ds = self.file[key] ds = self.file[key]
@ -175,3 +243,45 @@ class Hdf5Sample(Sample):
if key in self.file: if key in self.file:
del self.file[key] del self.file[key]
self.file.create_dataset(key, data=value) self.file.create_dataset(key, data=value)
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist)
lens = [len(v) if v is not None else -1 for v in veclist]
maxlen = max(lens)
# Find appropriate constant to pad the vectors
constant: Union[int, float, str, None] = None
for v in veclist:
if v is None or len(v) == 0:
continue
if isinstance(v[0], int):
constant = 0
elif isinstance(v[0], float):
constant = 0.0
elif isinstance(v[0], str):
constant = ""
else:
assert False, f"Unsupported data type: {v[0]}"
assert constant is not None, "veclist must not be completely empty"
# Pad vectors
for (i, vi) in enumerate(veclist):
if vi is None:
vi = veclist[i] = []
assert isinstance(vi, list)
for k in range(len(vi), maxlen):
vi.append(constant)
return veclist, lens
def _crop(veclist: VectorList, lens: List[int]) -> VectorList:
result: VectorList = cast(VectorList, [])
for (i, v) in enumerate(veclist):
if lens[i] < 0:
result.append(None) # type: ignore
else:
assert isinstance(v, list)
result.append(v[: lens[i]])
return result

@ -4,7 +4,7 @@
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop
def test_memory_sample() -> None: def test_memory_sample() -> None:
@ -29,16 +29,11 @@ def _test_sample(sample: Sample) -> None:
_assert_roundtrip_vector(sample, [1, 2, 3]) _assert_roundtrip_vector(sample, [1, 2, 3])
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0]) _assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
# List[Optional[List[Primitive]]] # VectorList
# _assert_roundtrip( _assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
# sample, _assert_roundtrip_vector_list(sample, [[True], [False, False], None])
# [ _assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]])
# [1], _assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]])
# None,
# [2, 2],
# [3, 3, 3],
# ],
# )
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
@ -57,8 +52,76 @@ def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None:
_assert_same_type(actual[0], expected[0]) _assert_same_type(actual[0], expected[0])
def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None:
sample.put_vector_list("key", expected)
actual = sample.get_vector_list("key")
assert actual == expected
assert actual is not None
_assert_same_type(actual[0][0], expected[0][0])
def _assert_same_type(actual: Any, expected: Any) -> None: def _assert_same_type(actual: Any, expected: Any) -> None:
assert isinstance(actual, expected.__class__), ( assert isinstance(
f"Expected class {expected.__class__}, " actual, expected.__class__
f"found class {actual.__class__} instead" ), f"Expected {expected.__class__}, found {actual.__class__} instead"
def test_pad_int() -> None:
_assert_roundtrip_pad(
original=[[1], [2, 2, 2], [], [3, 3], [4, 4, 4, 4], None],
expected_padded=[
[1, 0, 0, 0],
[2, 2, 2, 0],
[0, 0, 0, 0],
[3, 3, 0, 0],
[4, 4, 4, 4],
[0, 0, 0, 0],
],
expected_lens=[1, 3, 0, 2, 4, -1],
dtype=int,
)
def test_pad_float() -> None:
_assert_roundtrip_pad(
original=[[1.0], [2.0, 2.0, 2.0], [3.0, 3.0], [4.0, 4.0, 4.0, 4.0], None],
expected_padded=[
[1.0, 0.0, 0.0, 0.0],
[2.0, 2.0, 2.0, 0.0],
[3.0, 3.0, 0.0, 0.0],
[4.0, 4.0, 4.0, 4.0],
[0.0, 0.0, 0.0, 0.0],
],
expected_lens=[1, 3, 2, 4, -1],
dtype=float,
) )
def test_pad_str() -> None:
_assert_roundtrip_pad(
original=[["A"], ["B", "B", "B"], ["C", "C"]],
expected_padded=[["A", "", ""], ["B", "B", "B"], ["C", "C", ""]],
expected_lens=[1, 3, 2],
dtype=str,
)
def _assert_roundtrip_pad(
original: Any,
expected_padded: Any,
expected_lens: Any,
dtype: Any,
) -> None:
actual_padded, actual_lens = _pad(original)
assert actual_padded == expected_padded
assert actual_lens == expected_lens
for v in actual_padded:
for vi in v: # type: ignore
assert isinstance(vi, dtype)
cropped = _crop(actual_padded, actual_lens)
assert cropped == original
for v in cropped:
if v is None:
continue
for vi in v: # type: ignore
assert isinstance(vi, dtype)

Loading…
Cancel
Save