Implement {get,put}_vector_list

This commit is contained in:
2021-07-14 12:21:09 -05:00
parent 8fc7c6ab71
commit 8d89285cb9
3 changed files with 207 additions and 34 deletions

View File

@@ -46,7 +46,7 @@ class FeaturesExtractor:
self._extract_user_features_constrs(instance, sample)
self._extract_user_features_instance(instance, sample)
self._extract_var_features_AlvLouWeh2017(sample)
sample.put(
sample.put_vector_list(
"var_features",
self._combine(
sample,
@@ -82,7 +82,7 @@ class FeaturesExtractor:
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks)
self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_")
sample.put(
sample.put_vector_list(
"lp_var_features",
self._combine(
sample,
@@ -103,7 +103,7 @@ class FeaturesExtractor:
],
),
)
sample.put(
sample.put_vector_list(
"lp_constr_features",
self._combine(
sample,
@@ -118,7 +118,7 @@ class FeaturesExtractor:
)
instance_features_user = sample.get("instance_features_user")
assert instance_features_user is not None
sample.put(
sample.put_vector(
"lp_instance_features",
instance_features_user
+ [
@@ -178,7 +178,7 @@ class FeaturesExtractor:
user_features_i = list(user_features_i)
user_features.append(user_features_i)
sample.put("var_categories", categories)
sample.put("var_features_user", user_features)
sample.put_vector_list("var_features_user", user_features)
def _extract_user_features_constrs(
self,
@@ -227,7 +227,7 @@ class FeaturesExtractor:
lazy.append(instance.is_constraint_lazy(cname))
else:
lazy.append(False)
sample.put("constr_features_user", user_features)
sample.put_vector_list("constr_features_user", user_features)
sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories)
@@ -250,7 +250,7 @@ class FeaturesExtractor:
)
constr_lazy = sample.get("constr_lazy")
assert constr_lazy is not None
sample.put("instance_features_user", user_features)
sample.put_vector("instance_features_user", user_features)
sample.put_scalar("static_lazy_count", sum(constr_lazy))
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
@@ -331,7 +331,7 @@ class FeaturesExtractor:
for v in f:
assert isfinite(v), f"non-finite elements detected: {f}"
features.append(f)
sample.put(f"{prefix}var_features_AlvLouWeh2017", features)
sample.put_vector_list(f"{prefix}var_features_AlvLouWeh2017", features)
def _combine(
self,

View File

@@ -3,13 +3,25 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any, Union, List
from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast
import h5py
import numpy as np
from overrides import overrides
Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]]
VectorList = Union[
List[List[bool]],
List[List[str]],
List[List[int]],
List[List[float]],
List[Optional[List[bool]]],
List[Optional[List[str]]],
List[Optional[List[int]]],
List[Optional[List[float]]],
]
class Sample(ABC):
@@ -31,6 +43,14 @@ class Sample(ABC):
def put_vector(self, key: str, value: Vector) -> None:
    """
    Store the vector `value` under `key`.

    Abstract: the storage itself is provided by concrete subclasses. The
    implementations visible in this file return without storing anything
    when `value` is None, and otherwise validate the value before storing.
    """
    pass
@abstractmethod
def get_vector_list(self, key: str) -> Optional[Any]:
    """
    Return the list of vectors stored under `key`.

    Abstract: MemorySample and Hdf5Sample provide the implementations.
    """
    pass
@abstractmethod
def put_vector_list(self, key: str, value: VectorList) -> None:
    """
    Store a list of vectors under `key`.

    Per the VectorList alias, `value` is a list whose entries are lists of
    bool, str, int or float, or None. Abstract: MemorySample and Hdf5Sample
    provide the implementations.
    """
    pass
@abstractmethod
def get(self, key: str) -> Optional[Any]:
    """
    Return the value stored under `key`.

    Abstract: concrete subclasses provide the lookup. The typed getters of
    MemorySample delegate to this method.
    """
    pass
@@ -65,17 +85,24 @@ class Sample(ABC):
return
assert False, f"Value has unsupported type: {value}"
def _assert_scalar(self, value: Any) -> None:
def _assert_is_scalar(self, value: Any) -> None:
if value is None:
return
if isinstance(value, (str, bool, int, float)):
return
assert False, f"Scalar expected; found instead: {value}"
def _assert_vector(self, value: Any) -> None:
def _assert_is_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
self._assert_scalar(v)
self._assert_is_scalar(v)
def _assert_is_vector_list(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
if v is None:
continue
self._assert_is_vector(v)
class MemorySample(Sample):
@@ -93,20 +120,29 @@ class MemorySample(Sample):
def get_scalar(self, key: str) -> Optional[Any]:
    """Return the scalar stored under `key`, as looked up by `self.get`."""
    stored = self.get(key)
    return stored
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_scalar(value)
self.put(key, value)
@overrides
def get_vector(self, key: str) -> Optional[Any]:
    """Return the vector stored under `key`, as looked up by `self.get`."""
    stored = self.get(key)
    return stored
@overrides
def get_vector_list(self, key: str) -> Optional[Any]:
    """Return the vector list stored under `key`, as looked up by `self.get`."""
    stored = self.get(key)
    return stored
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_is_scalar(value)
self.put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_vector(value)
self._assert_is_vector(value)
self.put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
    """
    Validate `value` as a vector list, then store it under `key`.

    Raises AssertionError (from `_assert_is_vector_list`) if `value` is not
    a list of vectors / None entries; nothing is stored in that case.
    """
    # Validation must run first so that invalid data is never stored.
    self._assert_is_vector_list(value)
    self.put(key, value)
@overrides
@@ -145,23 +181,55 @@ class Hdf5Sample(Sample):
def get_vector(self, key: str) -> Optional[Any]:
    """
    Read the one-dimensional dataset stored under `key` as a Python list.

    String datasets are decoded to `str` through `asstr()`; any other
    dataset is converted with `tolist()` as-is.

    Raises AssertionError if the dataset is not one-dimensional.
    """
    ds = self.file[key]
    assert len(ds.shape) == 1
    # No debug output here: a getter should not write to stdout on each read.
    if h5py.check_string_dtype(ds.dtype):
        return ds.asstr()[:].tolist()
    else:
        return ds[:].tolist()
@overrides
def get_vector_list(self, key: str) -> Optional[Any]:
    """
    Read the vector list stored under `key`.

    The dataset holds the vectors padded to a common length; the original
    lengths are read from its "lengths" attribute (written by
    `put_vector_list`) and used by `_crop` to undo the padding.
    """
    ds = self.file[key]
    lens = ds.attrs["lengths"]
    # String datasets are read through asstr() so that entries come back as
    # str; every other dataset is read directly.
    source = ds.asstr() if h5py.check_string_dtype(ds.dtype) else ds
    return _crop(source[:].tolist(), lens)
@overrides
def put_scalar(self, key: str, value: Any) -> None:
self._assert_scalar(value)
self._assert_is_scalar(value)
self.put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_vector(value)
self._assert_is_vector(value)
self.put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
    """
    Store a list of vectors under `key` as a single padded dataset.

    Any existing dataset with the same key is deleted first. The vectors are
    padded to a common length by `_pad`, and the original lengths are saved
    in the dataset's "lengths" attribute.

    Raises AssertionError if `value` is not a valid vector list or if it
    contains no non-empty vector.
    """
    self._assert_is_vector_list(value)
    if key in self.file:
        del self.file[key]
    padded, lens = _pad(value)

    # The first non-empty vector decides the dtype of the whole dataset.
    probe = next((v for v in value if v is not None and len(v) != 0), None)
    assert probe is not None
    head = probe[0]
    if isinstance(head, str):
        # NOTE(review): dtype="S" yields byte strings; presumably non-ASCII
        # text cannot be stored this way -- confirm.
        data = np.array(padded, dtype="S")
    elif isinstance(head, bool):
        data = np.array(padded, dtype=bool)
    else:
        data = np.array(padded)

    ds = self.file.create_dataset(key, data=data)
    ds.attrs["lengths"] = lens
@overrides
def get(self, key: str) -> Optional[Any]:
ds = self.file[key]
@@ -175,3 +243,45 @@ class Hdf5Sample(Sample):
if key in self.file:
del self.file[key]
self.file.create_dataset(key, data=value)
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist)
lens = [len(v) if v is not None else -1 for v in veclist]
maxlen = max(lens)
# Find appropriate constant to pad the vectors
constant: Union[int, float, str, None] = None
for v in veclist:
if v is None or len(v) == 0:
continue
if isinstance(v[0], int):
constant = 0
elif isinstance(v[0], float):
constant = 0.0
elif isinstance(v[0], str):
constant = ""
else:
assert False, f"Unsupported data type: {v[0]}"
assert constant is not None, "veclist must not be completely empty"
# Pad vectors
for (i, vi) in enumerate(veclist):
if vi is None:
vi = veclist[i] = []
assert isinstance(vi, list)
for k in range(len(vi), maxlen):
vi.append(constant)
return veclist, lens
def _crop(veclist: VectorList, lens: List[int]) -> VectorList:
    """
    Undo padding by trimming each vector back to its recorded length.

    `lens[i]` is the length to keep for `veclist[i]`; a negative length
    marks an entry that is returned as None instead of a list.

    Raises AssertionError if an entry with non-negative length is not a list.
    """
    cropped: VectorList = cast(VectorList, [])
    for (idx, vec) in enumerate(veclist):
        size = lens[idx]
        if size < 0:
            cropped.append(None)  # type: ignore
            continue
        assert isinstance(vec, list)
        cropped.append(vec[:size])
    return cropped