Remove sample.{get,set}

This commit is contained in:
2021-07-27 09:00:04 -05:00
parent ef9c48d79a
commit 4224586d10
15 changed files with 184 additions and 211 deletions

View File

@@ -52,7 +52,7 @@ class DynamicConstraintsComponent(Component):
cids: Dict[str, List[str]] = {}
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
instance_features = sample.get("instance_features_user")
instance_features = sample.get_vector("instance_features_user")
assert instance_features is not None
for cid in self.known_cids:
# Initialize categories
@@ -81,7 +81,7 @@ class DynamicConstraintsComponent(Component):
cids[category].append(cid)
# Labels
enforced_cids = sample.get(self.attr)
enforced_cids = sample.get_set(self.attr)
if enforced_cids is not None:
if cid in enforced_cids:
y[category] += [[False, True]]
@@ -132,7 +132,7 @@ class DynamicConstraintsComponent(Component):
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return sample.get(self.attr)
return sample.get_set(self.attr)
@overrides
def fit_xy(
@@ -154,7 +154,7 @@ class DynamicConstraintsComponent(Component):
instance: Instance,
sample: Sample,
) -> Dict[str, Dict[str, float]]:
actual = sample.get(self.attr)
actual = sample.get_set(self.attr)
assert actual is not None
pred = set(self.sample_predict(instance, sample))
tp: Dict[str, int] = {}

View File

@@ -78,7 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
sample.put("lazy_enforced", set(self.lazy_enforced))
sample.put_set("lazy_enforced", set(self.lazy_enforced))
@overrides
def iteration_cb(

View File

@@ -87,7 +87,7 @@ class UserCutsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
sample.put("user_cuts_enforced", set(self.enforced))
sample.put_set("user_cuts_enforced", set(self.enforced))
stats["UserCuts: Added in callback"] = self.n_added_in_callback
if self.n_added_in_callback > 0:
logger.info(f"{self.n_added_in_callback} user cuts added in callback")

View File

@@ -77,9 +77,9 @@ class ObjectiveValueComponent(Component):
_: Optional[Instance],
sample: Sample,
) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
lp_instance_features = sample.get("lp_instance_features")
lp_instance_features = sample.get_vector("lp_instance_features")
if lp_instance_features is None:
lp_instance_features = sample.get("instance_features_user")
lp_instance_features = sample.get_vector("instance_features_user")
assert lp_instance_features is not None
# Features
@@ -90,8 +90,8 @@ class ObjectiveValueComponent(Component):
# Labels
y: Dict[str, List[List[float]]] = {}
mip_lower_bound = sample.get("mip_lower_bound")
mip_upper_bound = sample.get("mip_upper_bound")
mip_lower_bound = sample.get_scalar("mip_lower_bound")
mip_upper_bound = sample.get_scalar("mip_upper_bound")
if mip_lower_bound is not None:
y["Lower bound"] = [[mip_lower_bound]]
if mip_upper_bound is not None:
@@ -116,8 +116,8 @@ class ObjectiveValueComponent(Component):
result: Dict[str, Dict[str, float]] = {}
pred = self.sample_predict(sample)
actual_ub = sample.get("mip_upper_bound")
actual_lb = sample.get("mip_lower_bound")
actual_ub = sample.get_scalar("mip_upper_bound")
actual_lb = sample.get_scalar("mip_lower_bound")
if actual_ub is not None:
result["Upper bound"] = compare(pred["Upper bound"], actual_ub)
if actual_lb is not None:

View File

@@ -95,8 +95,8 @@ class PrimalSolutionComponent(Component):
)
def sample_predict(self, sample: Sample) -> Solution:
var_names = sample.get("var_names")
var_categories = sample.get("var_categories")
var_names = sample.get_vector("var_names")
var_categories = sample.get_vector("var_categories")
assert var_names is not None
assert var_categories is not None
@@ -142,13 +142,13 @@ class PrimalSolutionComponent(Component):
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {}
y: Dict = {}
instance_features = sample.get("instance_features_user")
mip_var_values = sample.get("mip_var_values")
var_features = sample.get("lp_var_features")
var_names = sample.get("var_names")
var_categories = sample.get("var_categories")
instance_features = sample.get_vector("instance_features_user")
mip_var_values = sample.get_vector("mip_var_values")
var_features = sample.get_vector_list("lp_var_features")
var_names = sample.get_vector("var_names")
var_categories = sample.get_vector("var_categories")
if var_features is None:
var_features = sample.get("var_features")
var_features = sample.get_vector_list("var_features")
assert instance_features is not None
assert var_features is not None
assert var_names is not None
@@ -187,8 +187,8 @@ class PrimalSolutionComponent(Component):
_: Optional[Instance],
sample: Sample,
) -> Dict[str, Dict[str, float]]:
mip_var_values = sample.get("mip_var_values")
var_names = sample.get("var_names")
mip_var_values = sample.get_vector("mip_var_values")
var_names = sample.get_vector("var_names")
assert mip_var_values is not None
assert var_names is not None

View File

@@ -61,7 +61,7 @@ class StaticLazyConstraintsComponent(Component):
stats: LearningSolveStats,
sample: Sample,
) -> None:
sample.put("lazy_enforced", self.enforced_cids)
sample.put_set("lazy_enforced", self.enforced_cids)
stats["LazyStatic: Restored"] = self.n_restored
stats["LazyStatic: Iterations"] = self.n_iterations
@@ -75,7 +75,7 @@ class StaticLazyConstraintsComponent(Component):
sample: Sample,
) -> None:
assert solver.internal_solver is not None
static_lazy_count = sample.get("static_lazy_count")
static_lazy_count = sample.get_scalar("static_lazy_count")
assert static_lazy_count is not None
logger.info("Predicting violated (static) lazy constraints...")
@@ -204,14 +204,14 @@ class StaticLazyConstraintsComponent(Component):
x: Dict[str, List[List[float]]] = {}
y: Dict[str, List[List[float]]] = {}
cids: Dict[str, List[str]] = {}
instance_features = sample.get("instance_features_user")
constr_features = sample.get("lp_constr_features")
constr_names = sample.get("constr_names")
constr_categories = sample.get("constr_categories")
constr_lazy = sample.get("constr_lazy")
lazy_enforced = sample.get("lazy_enforced")
instance_features = sample.get_vector("instance_features_user")
constr_features = sample.get_vector_list("lp_constr_features")
constr_names = sample.get_vector("constr_names")
constr_categories = sample.get_vector("constr_categories")
constr_lazy = sample.get_vector("constr_lazy")
lazy_enforced = sample.get_set("lazy_enforced")
if constr_features is None:
constr_features = sample.get("constr_features_user")
constr_features = sample.get_vector_list("constr_features_user")
assert instance_features is not None
assert constr_features is not None

View File

@@ -39,7 +39,7 @@ class FeaturesExtractor:
sample.put_vector("var_types", variables.types)
sample.put_vector("var_upper_bounds", variables.upper_bounds)
sample.put_vector("constr_names", constraints.names)
sample.put("constr_lhs", constraints.lhs)
# sample.put("constr_lhs", constraints.lhs)
sample.put_vector("constr_rhs", constraints.rhs)
sample.put_vector("constr_senses", constraints.senses)
self._extract_user_features_vars(instance, sample)
@@ -49,13 +49,12 @@ class FeaturesExtractor:
sample.put_vector_list(
"var_features",
self._combine(
sample,
[
"var_features_AlvLouWeh2017",
"var_features_user",
"var_lower_bounds",
"var_obj_coeffs",
"var_upper_bounds",
sample.get_vector_list("var_features_AlvLouWeh2017"),
sample.get_vector_list("var_features_user"),
sample.get_vector("var_lower_bounds"),
sample.get_vector("var_obj_coeffs"),
sample.get_vector("var_upper_bounds"),
],
),
)
@@ -85,45 +84,43 @@ class FeaturesExtractor:
sample.put_vector_list(
"lp_var_features",
self._combine(
sample,
[
"lp_var_features_AlvLouWeh2017",
"lp_var_reduced_costs",
"lp_var_sa_lb_down",
"lp_var_sa_lb_up",
"lp_var_sa_obj_down",
"lp_var_sa_obj_up",
"lp_var_sa_ub_down",
"lp_var_sa_ub_up",
"lp_var_values",
"var_features_user",
"var_lower_bounds",
"var_obj_coeffs",
"var_upper_bounds",
sample.get_vector_list("lp_var_features_AlvLouWeh2017"),
sample.get_vector("lp_var_reduced_costs"),
sample.get_vector("lp_var_sa_lb_down"),
sample.get_vector("lp_var_sa_lb_up"),
sample.get_vector("lp_var_sa_obj_down"),
sample.get_vector("lp_var_sa_obj_up"),
sample.get_vector("lp_var_sa_ub_down"),
sample.get_vector("lp_var_sa_ub_up"),
sample.get_vector("lp_var_values"),
sample.get_vector_list("var_features_user"),
sample.get_vector("var_lower_bounds"),
sample.get_vector("var_obj_coeffs"),
sample.get_vector("var_upper_bounds"),
],
),
)
sample.put_vector_list(
"lp_constr_features",
self._combine(
sample,
[
"constr_features_user",
"lp_constr_dual_values",
"lp_constr_sa_rhs_down",
"lp_constr_sa_rhs_up",
"lp_constr_slacks",
sample.get_vector_list("constr_features_user"),
sample.get_vector("lp_constr_dual_values"),
sample.get_vector("lp_constr_sa_rhs_down"),
sample.get_vector("lp_constr_sa_rhs_up"),
sample.get_vector("lp_constr_slacks"),
],
),
)
instance_features_user = sample.get("instance_features_user")
instance_features_user = sample.get_vector("instance_features_user")
assert instance_features_user is not None
sample.put_vector(
"lp_instance_features",
instance_features_user
+ [
sample.get("lp_value"),
sample.get("lp_wallclock_time"),
sample.get_scalar("lp_value"),
sample.get_scalar("lp_wallclock_time"),
],
)
@@ -146,7 +143,7 @@ class FeaturesExtractor:
user_features: List[Optional[List[float]]] = []
var_features_dict = instance.get_variable_features()
var_categories_dict = instance.get_variable_categories()
var_names = sample.get("var_names")
var_names = sample.get_vector("var_names")
assert var_names is not None
for (i, var_name) in enumerate(var_names):
if var_name not in var_categories_dict:
@@ -177,7 +174,7 @@ class FeaturesExtractor:
)
user_features_i = list(user_features_i)
user_features.append(user_features_i)
sample.put("var_categories", categories)
sample.put_vector("var_categories", categories)
sample.put_vector_list("var_features_user", user_features)
def _extract_user_features_constrs(
@@ -191,7 +188,7 @@ class FeaturesExtractor:
lazy: List[bool] = []
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
constr_names = sample.get("constr_names")
constr_names = sample.get_vector("constr_names")
assert constr_names is not None
for (cidx, cname) in enumerate(constr_names):
@@ -229,7 +226,7 @@ class FeaturesExtractor:
lazy.append(False)
sample.put_vector_list("constr_features_user", user_features)
sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories)
sample.put_vector("constr_categories", categories)
def _extract_user_features_instance(
self,
@@ -248,7 +245,7 @@ class FeaturesExtractor:
f"Instance features must be a list of numbers. "
f"Found {type(v).__name__} instead."
)
constr_lazy = sample.get("constr_lazy")
constr_lazy = sample.get_vector("constr_lazy")
assert constr_lazy is not None
sample.put_vector("instance_features_user", user_features)
sample.put_scalar("static_lazy_count", sum(constr_lazy))
@@ -260,10 +257,10 @@ class FeaturesExtractor:
sample: Sample,
prefix: str = "",
) -> None:
obj_coeffs = sample.get("var_obj_coeffs")
obj_sa_down = sample.get("lp_var_sa_obj_down")
obj_sa_up = sample.get("lp_var_sa_obj_up")
values = sample.get(f"lp_var_values")
obj_coeffs = sample.get_vector("var_obj_coeffs")
obj_sa_down = sample.get_vector("lp_var_sa_obj_down")
obj_sa_up = sample.get_vector("lp_var_sa_obj_up")
values = sample.get_vector(f"lp_var_values")
assert obj_coeffs is not None
pos_obj_coeff_sum = 0.0
@@ -335,12 +332,10 @@ class FeaturesExtractor:
def _combine(
self,
sample: Sample,
attrs: List[str],
items: List,
) -> List[List[float]]:
combined: List[List[float]] = []
for attr in attrs:
series = sample.get(attr)
for series in items:
if series is None:
continue
if len(combined) == 0:

View File

@@ -4,14 +4,22 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast
from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
import h5py
import numpy as np
from h5py import Dataset
from overrides import overrides
Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]]
Vector = Union[
None,
List[bool],
List[str],
List[int],
List[float],
List[Optional[str]],
]
VectorList = Union[
List[List[bool]],
List[List[str]],
@@ -51,39 +59,16 @@ class Sample(ABC):
def put_vector_list(self, key: str, value: VectorList) -> None:
pass
@abstractmethod
def get(self, key: str) -> Optional[Any]:
pass
def get_set(self, key: str) -> Set:
v = self.get_vector(key)
if v:
return set(v)
else:
return set()
@abstractmethod
def put(self, key: str, value: Any) -> None:
"""
Add a new key/value pair to the sample. If the key already exists,
the previous value is silently replaced.
Only the following data types are supported:
- str, bool, int, float
- List[str], List[bool], List[int], List[float]
"""
pass
def _assert_supported(self, value: Any) -> None:
def _is_primitive(v: Any) -> bool:
if isinstance(v, (str, bool, int, float)):
return True
if v is None:
return True
return False
if _is_primitive(value):
return
if isinstance(value, list):
if _is_primitive(value[0]):
return
if isinstance(value[0], list):
if _is_primitive(value[0][0]):
return
assert False, f"Value has unsupported type: {value}"
def put_set(self, key: str, value: Set) -> None:
v = list(value)
self.put_vector(key, v)
def _assert_is_scalar(self, value: Any) -> None:
if value is None:
@@ -118,42 +103,40 @@ class MemorySample(Sample):
@overrides
def get_scalar(self, key: str) -> Optional[Any]:
return self.get(key)
return self._get(key)
@overrides
def get_vector(self, key: str) -> Optional[Any]:
return self.get(key)
return self._get(key)
@overrides
def get_vector_list(self, key: str) -> Optional[Any]:
return self.get(key)
return self._get(key)
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_is_scalar(value)
self.put(key, value)
self._put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_is_vector(value)
self.put(key, value)
self._put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value)
self.put(key, value)
self._put(key, value)
@overrides
def get(self, key: str) -> Optional[Any]:
def _get(self, key: str) -> Optional[Any]:
if key in self._data:
return self._data[key]
else:
return None
@overrides
def put(self, key: str, value: Any) -> None:
def _put(self, key: str, value: Any) -> None:
self._data[key] = value
@@ -200,20 +183,18 @@ class Hdf5Sample(Sample):
@overrides
def put_scalar(self, key: str, value: Any) -> None:
self._assert_is_scalar(value)
self.put(key, value)
self._put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_is_vector(value)
self.put(key, value)
self._put(key, value)
@overrides
def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value)
if key in self.file:
del self.file[key]
padded, lens = _pad(value)
data = None
for v in value:
@@ -227,22 +208,13 @@ class Hdf5Sample(Sample):
data = np.array(padded)
break
assert data is not None
ds = self.file.create_dataset(key, data=data)
ds = self._put(key, data)
ds.attrs["lengths"] = lens
@overrides
def get(self, key: str) -> Optional[Any]:
ds = self.file[key]
if h5py.check_string_dtype(ds.dtype):
return ds.asstr()[:].tolist()
else:
return ds[:].tolist()
@overrides
def put(self, key: str, value: Any) -> None:
def _put(self, key: str, value: Any) -> Dataset:
if key in self.file:
del self.file[key]
self.file.create_dataset(key, data=value)
return self.file.create_dataset(key, data=value)
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:

View File

@@ -89,15 +89,14 @@ class TravelingSalesmanInstance(Instance):
self,
solver: InternalSolver,
model: Any,
) -> List[FrozenSet]:
) -> List[str]:
selected_edges = [e for e in self.edges if model.x[e].value > 0.5]
graph = nx.Graph()
graph.add_edges_from(selected_edges)
components = [frozenset(c) for c in list(nx.connected_components(graph))]
violations = []
for c in components:
for c in list(nx.connected_components(graph)):
if len(c) < self.n_cities:
violations += [c]
violations.append(",".join(map(str, c)))
return violations
@overrides
@@ -105,9 +104,10 @@ class TravelingSalesmanInstance(Instance):
self,
solver: InternalSolver,
model: Any,
component: FrozenSet,
violation: str,
) -> None:
assert isinstance(solver, BasePyomoSolver)
component = [int(v) for v in violation.split(",")]
cut_edges = [
e
for e in self.edges

View File

@@ -80,16 +80,16 @@ class Constraints:
@staticmethod
def from_sample(sample: "Sample") -> "Constraints":
return Constraints(
basis_status=sample.get("lp_constr_basis_status"),
dual_values=sample.get("lp_constr_dual_values"),
lazy=sample.get("constr_lazy"),
lhs=sample.get("constr_lhs"),
names=sample.get("constr_names"),
rhs=sample.get("constr_rhs"),
sa_rhs_down=sample.get("lp_constr_sa_rhs_down"),
sa_rhs_up=sample.get("lp_constr_sa_rhs_up"),
senses=sample.get("constr_senses"),
slacks=sample.get("lp_constr_slacks"),
basis_status=sample.get_vector("lp_constr_basis_status"),
dual_values=sample.get_vector("lp_constr_dual_values"),
lazy=sample.get_vector("constr_lazy"),
# lhs=sample.get_vector("constr_lhs"),
names=sample.get_vector("constr_names"),
rhs=sample.get_vector("constr_rhs"),
sa_rhs_down=sample.get_vector("lp_constr_sa_rhs_down"),
sa_rhs_up=sample.get_vector("lp_constr_sa_rhs_up"),
senses=sample.get_vector("constr_senses"),
slacks=sample.get_vector("lp_constr_slacks"),
)
def __getitem__(self, selected: List[bool]) -> "Constraints":