Remove sample.{get,set}

master
Alinson S. Xavier 4 years ago
parent ef9c48d79a
commit 4224586d10
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -52,7 +52,7 @@ class DynamicConstraintsComponent(Component):
cids: Dict[str, List[str]] = {} cids: Dict[str, List[str]] = {}
constr_categories_dict = instance.get_constraint_categories() constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features() constr_features_dict = instance.get_constraint_features()
instance_features = sample.get("instance_features_user") instance_features = sample.get_vector("instance_features_user")
assert instance_features is not None assert instance_features is not None
for cid in self.known_cids: for cid in self.known_cids:
# Initialize categories # Initialize categories
@ -81,7 +81,7 @@ class DynamicConstraintsComponent(Component):
cids[category].append(cid) cids[category].append(cid)
# Labels # Labels
enforced_cids = sample.get(self.attr) enforced_cids = sample.get_set(self.attr)
if enforced_cids is not None: if enforced_cids is not None:
if cid in enforced_cids: if cid in enforced_cids:
y[category] += [[False, True]] y[category] += [[False, True]]
@ -132,7 +132,7 @@ class DynamicConstraintsComponent(Component):
@overrides @overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any: def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return sample.get(self.attr) return sample.get_set(self.attr)
@overrides @overrides
def fit_xy( def fit_xy(
@ -154,7 +154,7 @@ class DynamicConstraintsComponent(Component):
instance: Instance, instance: Instance,
sample: Sample, sample: Sample,
) -> Dict[str, Dict[str, float]]: ) -> Dict[str, Dict[str, float]]:
actual = sample.get(self.attr) actual = sample.get_set(self.attr)
assert actual is not None assert actual is not None
pred = set(self.sample_predict(instance, sample)) pred = set(self.sample_predict(instance, sample))
tp: Dict[str, int] = {} tp: Dict[str, int] = {}

@ -78,7 +78,7 @@ class DynamicLazyConstraintsComponent(Component):
stats: LearningSolveStats, stats: LearningSolveStats,
sample: Sample, sample: Sample,
) -> None: ) -> None:
sample.put("lazy_enforced", set(self.lazy_enforced)) sample.put_set("lazy_enforced", set(self.lazy_enforced))
@overrides @overrides
def iteration_cb( def iteration_cb(

@ -87,7 +87,7 @@ class UserCutsComponent(Component):
stats: LearningSolveStats, stats: LearningSolveStats,
sample: Sample, sample: Sample,
) -> None: ) -> None:
sample.put("user_cuts_enforced", set(self.enforced)) sample.put_set("user_cuts_enforced", set(self.enforced))
stats["UserCuts: Added in callback"] = self.n_added_in_callback stats["UserCuts: Added in callback"] = self.n_added_in_callback
if self.n_added_in_callback > 0: if self.n_added_in_callback > 0:
logger.info(f"{self.n_added_in_callback} user cuts added in callback") logger.info(f"{self.n_added_in_callback} user cuts added in callback")

@ -77,9 +77,9 @@ class ObjectiveValueComponent(Component):
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]: ) -> Tuple[Dict[str, List[List[float]]], Dict[str, List[List[float]]]]:
lp_instance_features = sample.get("lp_instance_features") lp_instance_features = sample.get_vector("lp_instance_features")
if lp_instance_features is None: if lp_instance_features is None:
lp_instance_features = sample.get("instance_features_user") lp_instance_features = sample.get_vector("instance_features_user")
assert lp_instance_features is not None assert lp_instance_features is not None
# Features # Features
@ -90,8 +90,8 @@ class ObjectiveValueComponent(Component):
# Labels # Labels
y: Dict[str, List[List[float]]] = {} y: Dict[str, List[List[float]]] = {}
mip_lower_bound = sample.get("mip_lower_bound") mip_lower_bound = sample.get_scalar("mip_lower_bound")
mip_upper_bound = sample.get("mip_upper_bound") mip_upper_bound = sample.get_scalar("mip_upper_bound")
if mip_lower_bound is not None: if mip_lower_bound is not None:
y["Lower bound"] = [[mip_lower_bound]] y["Lower bound"] = [[mip_lower_bound]]
if mip_upper_bound is not None: if mip_upper_bound is not None:
@ -116,8 +116,8 @@ class ObjectiveValueComponent(Component):
result: Dict[str, Dict[str, float]] = {} result: Dict[str, Dict[str, float]] = {}
pred = self.sample_predict(sample) pred = self.sample_predict(sample)
actual_ub = sample.get("mip_upper_bound") actual_ub = sample.get_scalar("mip_upper_bound")
actual_lb = sample.get("mip_lower_bound") actual_lb = sample.get_scalar("mip_lower_bound")
if actual_ub is not None: if actual_ub is not None:
result["Upper bound"] = compare(pred["Upper bound"], actual_ub) result["Upper bound"] = compare(pred["Upper bound"], actual_ub)
if actual_lb is not None: if actual_lb is not None:

@ -95,8 +95,8 @@ class PrimalSolutionComponent(Component):
) )
def sample_predict(self, sample: Sample) -> Solution: def sample_predict(self, sample: Sample) -> Solution:
var_names = sample.get("var_names") var_names = sample.get_vector("var_names")
var_categories = sample.get("var_categories") var_categories = sample.get_vector("var_categories")
assert var_names is not None assert var_names is not None
assert var_categories is not None assert var_categories is not None
@ -142,13 +142,13 @@ class PrimalSolutionComponent(Component):
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]: ) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
x: Dict = {} x: Dict = {}
y: Dict = {} y: Dict = {}
instance_features = sample.get("instance_features_user") instance_features = sample.get_vector("instance_features_user")
mip_var_values = sample.get("mip_var_values") mip_var_values = sample.get_vector("mip_var_values")
var_features = sample.get("lp_var_features") var_features = sample.get_vector_list("lp_var_features")
var_names = sample.get("var_names") var_names = sample.get_vector("var_names")
var_categories = sample.get("var_categories") var_categories = sample.get_vector("var_categories")
if var_features is None: if var_features is None:
var_features = sample.get("var_features") var_features = sample.get_vector_list("var_features")
assert instance_features is not None assert instance_features is not None
assert var_features is not None assert var_features is not None
assert var_names is not None assert var_names is not None
@ -187,8 +187,8 @@ class PrimalSolutionComponent(Component):
_: Optional[Instance], _: Optional[Instance],
sample: Sample, sample: Sample,
) -> Dict[str, Dict[str, float]]: ) -> Dict[str, Dict[str, float]]:
mip_var_values = sample.get("mip_var_values") mip_var_values = sample.get_vector("mip_var_values")
var_names = sample.get("var_names") var_names = sample.get_vector("var_names")
assert mip_var_values is not None assert mip_var_values is not None
assert var_names is not None assert var_names is not None

@ -61,7 +61,7 @@ class StaticLazyConstraintsComponent(Component):
stats: LearningSolveStats, stats: LearningSolveStats,
sample: Sample, sample: Sample,
) -> None: ) -> None:
sample.put("lazy_enforced", self.enforced_cids) sample.put_set("lazy_enforced", self.enforced_cids)
stats["LazyStatic: Restored"] = self.n_restored stats["LazyStatic: Restored"] = self.n_restored
stats["LazyStatic: Iterations"] = self.n_iterations stats["LazyStatic: Iterations"] = self.n_iterations
@ -75,7 +75,7 @@ class StaticLazyConstraintsComponent(Component):
sample: Sample, sample: Sample,
) -> None: ) -> None:
assert solver.internal_solver is not None assert solver.internal_solver is not None
static_lazy_count = sample.get("static_lazy_count") static_lazy_count = sample.get_scalar("static_lazy_count")
assert static_lazy_count is not None assert static_lazy_count is not None
logger.info("Predicting violated (static) lazy constraints...") logger.info("Predicting violated (static) lazy constraints...")
@ -204,14 +204,14 @@ class StaticLazyConstraintsComponent(Component):
x: Dict[str, List[List[float]]] = {} x: Dict[str, List[List[float]]] = {}
y: Dict[str, List[List[float]]] = {} y: Dict[str, List[List[float]]] = {}
cids: Dict[str, List[str]] = {} cids: Dict[str, List[str]] = {}
instance_features = sample.get("instance_features_user") instance_features = sample.get_vector("instance_features_user")
constr_features = sample.get("lp_constr_features") constr_features = sample.get_vector_list("lp_constr_features")
constr_names = sample.get("constr_names") constr_names = sample.get_vector("constr_names")
constr_categories = sample.get("constr_categories") constr_categories = sample.get_vector("constr_categories")
constr_lazy = sample.get("constr_lazy") constr_lazy = sample.get_vector("constr_lazy")
lazy_enforced = sample.get("lazy_enforced") lazy_enforced = sample.get_set("lazy_enforced")
if constr_features is None: if constr_features is None:
constr_features = sample.get("constr_features_user") constr_features = sample.get_vector_list("constr_features_user")
assert instance_features is not None assert instance_features is not None
assert constr_features is not None assert constr_features is not None

@ -39,7 +39,7 @@ class FeaturesExtractor:
sample.put_vector("var_types", variables.types) sample.put_vector("var_types", variables.types)
sample.put_vector("var_upper_bounds", variables.upper_bounds) sample.put_vector("var_upper_bounds", variables.upper_bounds)
sample.put_vector("constr_names", constraints.names) sample.put_vector("constr_names", constraints.names)
sample.put("constr_lhs", constraints.lhs) # sample.put("constr_lhs", constraints.lhs)
sample.put_vector("constr_rhs", constraints.rhs) sample.put_vector("constr_rhs", constraints.rhs)
sample.put_vector("constr_senses", constraints.senses) sample.put_vector("constr_senses", constraints.senses)
self._extract_user_features_vars(instance, sample) self._extract_user_features_vars(instance, sample)
@ -49,13 +49,12 @@ class FeaturesExtractor:
sample.put_vector_list( sample.put_vector_list(
"var_features", "var_features",
self._combine( self._combine(
sample,
[ [
"var_features_AlvLouWeh2017", sample.get_vector_list("var_features_AlvLouWeh2017"),
"var_features_user", sample.get_vector_list("var_features_user"),
"var_lower_bounds", sample.get_vector("var_lower_bounds"),
"var_obj_coeffs", sample.get_vector("var_obj_coeffs"),
"var_upper_bounds", sample.get_vector("var_upper_bounds"),
], ],
), ),
) )
@ -85,45 +84,43 @@ class FeaturesExtractor:
sample.put_vector_list( sample.put_vector_list(
"lp_var_features", "lp_var_features",
self._combine( self._combine(
sample,
[ [
"lp_var_features_AlvLouWeh2017", sample.get_vector_list("lp_var_features_AlvLouWeh2017"),
"lp_var_reduced_costs", sample.get_vector("lp_var_reduced_costs"),
"lp_var_sa_lb_down", sample.get_vector("lp_var_sa_lb_down"),
"lp_var_sa_lb_up", sample.get_vector("lp_var_sa_lb_up"),
"lp_var_sa_obj_down", sample.get_vector("lp_var_sa_obj_down"),
"lp_var_sa_obj_up", sample.get_vector("lp_var_sa_obj_up"),
"lp_var_sa_ub_down", sample.get_vector("lp_var_sa_ub_down"),
"lp_var_sa_ub_up", sample.get_vector("lp_var_sa_ub_up"),
"lp_var_values", sample.get_vector("lp_var_values"),
"var_features_user", sample.get_vector_list("var_features_user"),
"var_lower_bounds", sample.get_vector("var_lower_bounds"),
"var_obj_coeffs", sample.get_vector("var_obj_coeffs"),
"var_upper_bounds", sample.get_vector("var_upper_bounds"),
], ],
), ),
) )
sample.put_vector_list( sample.put_vector_list(
"lp_constr_features", "lp_constr_features",
self._combine( self._combine(
sample,
[ [
"constr_features_user", sample.get_vector_list("constr_features_user"),
"lp_constr_dual_values", sample.get_vector("lp_constr_dual_values"),
"lp_constr_sa_rhs_down", sample.get_vector("lp_constr_sa_rhs_down"),
"lp_constr_sa_rhs_up", sample.get_vector("lp_constr_sa_rhs_up"),
"lp_constr_slacks", sample.get_vector("lp_constr_slacks"),
], ],
), ),
) )
instance_features_user = sample.get("instance_features_user") instance_features_user = sample.get_vector("instance_features_user")
assert instance_features_user is not None assert instance_features_user is not None
sample.put_vector( sample.put_vector(
"lp_instance_features", "lp_instance_features",
instance_features_user instance_features_user
+ [ + [
sample.get("lp_value"), sample.get_scalar("lp_value"),
sample.get("lp_wallclock_time"), sample.get_scalar("lp_wallclock_time"),
], ],
) )
@ -146,7 +143,7 @@ class FeaturesExtractor:
user_features: List[Optional[List[float]]] = [] user_features: List[Optional[List[float]]] = []
var_features_dict = instance.get_variable_features() var_features_dict = instance.get_variable_features()
var_categories_dict = instance.get_variable_categories() var_categories_dict = instance.get_variable_categories()
var_names = sample.get("var_names") var_names = sample.get_vector("var_names")
assert var_names is not None assert var_names is not None
for (i, var_name) in enumerate(var_names): for (i, var_name) in enumerate(var_names):
if var_name not in var_categories_dict: if var_name not in var_categories_dict:
@ -177,7 +174,7 @@ class FeaturesExtractor:
) )
user_features_i = list(user_features_i) user_features_i = list(user_features_i)
user_features.append(user_features_i) user_features.append(user_features_i)
sample.put("var_categories", categories) sample.put_vector("var_categories", categories)
sample.put_vector_list("var_features_user", user_features) sample.put_vector_list("var_features_user", user_features)
def _extract_user_features_constrs( def _extract_user_features_constrs(
@ -191,7 +188,7 @@ class FeaturesExtractor:
lazy: List[bool] = [] lazy: List[bool] = []
constr_categories_dict = instance.get_constraint_categories() constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features() constr_features_dict = instance.get_constraint_features()
constr_names = sample.get("constr_names") constr_names = sample.get_vector("constr_names")
assert constr_names is not None assert constr_names is not None
for (cidx, cname) in enumerate(constr_names): for (cidx, cname) in enumerate(constr_names):
@ -229,7 +226,7 @@ class FeaturesExtractor:
lazy.append(False) lazy.append(False)
sample.put_vector_list("constr_features_user", user_features) sample.put_vector_list("constr_features_user", user_features)
sample.put_vector("constr_lazy", lazy) sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories) sample.put_vector("constr_categories", categories)
def _extract_user_features_instance( def _extract_user_features_instance(
self, self,
@ -248,7 +245,7 @@ class FeaturesExtractor:
f"Instance features must be a list of numbers. " f"Instance features must be a list of numbers. "
f"Found {type(v).__name__} instead." f"Found {type(v).__name__} instead."
) )
constr_lazy = sample.get("constr_lazy") constr_lazy = sample.get_vector("constr_lazy")
assert constr_lazy is not None assert constr_lazy is not None
sample.put_vector("instance_features_user", user_features) sample.put_vector("instance_features_user", user_features)
sample.put_scalar("static_lazy_count", sum(constr_lazy)) sample.put_scalar("static_lazy_count", sum(constr_lazy))
@ -260,10 +257,10 @@ class FeaturesExtractor:
sample: Sample, sample: Sample,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
obj_coeffs = sample.get("var_obj_coeffs") obj_coeffs = sample.get_vector("var_obj_coeffs")
obj_sa_down = sample.get("lp_var_sa_obj_down") obj_sa_down = sample.get_vector("lp_var_sa_obj_down")
obj_sa_up = sample.get("lp_var_sa_obj_up") obj_sa_up = sample.get_vector("lp_var_sa_obj_up")
values = sample.get(f"lp_var_values") values = sample.get_vector(f"lp_var_values")
assert obj_coeffs is not None assert obj_coeffs is not None
pos_obj_coeff_sum = 0.0 pos_obj_coeff_sum = 0.0
@ -335,12 +332,10 @@ class FeaturesExtractor:
def _combine( def _combine(
self, self,
sample: Sample, items: List,
attrs: List[str],
) -> List[List[float]]: ) -> List[List[float]]:
combined: List[List[float]] = [] combined: List[List[float]] = []
for attr in attrs: for series in items:
series = sample.get(attr)
if series is None: if series is None:
continue continue
if len(combined) == 0: if len(combined) == 0:

@ -4,14 +4,22 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
import h5py import h5py
import numpy as np import numpy as np
from h5py import Dataset
from overrides import overrides from overrides import overrides
Scalar = Union[None, bool, str, int, float] Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]] Vector = Union[
None,
List[bool],
List[str],
List[int],
List[float],
List[Optional[str]],
]
VectorList = Union[ VectorList = Union[
List[List[bool]], List[List[bool]],
List[List[str]], List[List[str]],
@ -51,39 +59,16 @@ class Sample(ABC):
def put_vector_list(self, key: str, value: VectorList) -> None: def put_vector_list(self, key: str, value: VectorList) -> None:
pass pass
@abstractmethod def get_set(self, key: str) -> Set:
def get(self, key: str) -> Optional[Any]: v = self.get_vector(key)
pass if v:
return set(v)
@abstractmethod else:
def put(self, key: str, value: Any) -> None: return set()
"""
Add a new key/value pair to the sample. If the key already exists,
the previous value is silently replaced.
Only the following data types are supported:
- str, bool, int, float
- List[str], List[bool], List[int], List[float]
"""
pass
def _assert_supported(self, value: Any) -> None: def put_set(self, key: str, value: Set) -> None:
def _is_primitive(v: Any) -> bool: v = list(value)
if isinstance(v, (str, bool, int, float)): self.put_vector(key, v)
return True
if v is None:
return True
return False
if _is_primitive(value):
return
if isinstance(value, list):
if _is_primitive(value[0]):
return
if isinstance(value[0], list):
if _is_primitive(value[0][0]):
return
assert False, f"Value has unsupported type: {value}"
def _assert_is_scalar(self, value: Any) -> None: def _assert_is_scalar(self, value: Any) -> None:
if value is None: if value is None:
@ -118,42 +103,40 @@ class MemorySample(Sample):
@overrides @overrides
def get_scalar(self, key: str) -> Optional[Any]: def get_scalar(self, key: str) -> Optional[Any]:
return self.get(key) return self._get(key)
@overrides @overrides
def get_vector(self, key: str) -> Optional[Any]: def get_vector(self, key: str) -> Optional[Any]:
return self.get(key) return self._get(key)
@overrides @overrides
def get_vector_list(self, key: str) -> Optional[Any]: def get_vector_list(self, key: str) -> Optional[Any]:
return self.get(key) return self._get(key)
@overrides @overrides
def put_scalar(self, key: str, value: Scalar) -> None: def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_is_scalar(value) self._assert_is_scalar(value)
self.put(key, value) self._put(key, value)
@overrides @overrides
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
if value is None: if value is None:
return return
self._assert_is_vector(value) self._assert_is_vector(value)
self.put(key, value) self._put(key, value)
@overrides @overrides
def put_vector_list(self, key: str, value: VectorList) -> None: def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value) self._assert_is_vector_list(value)
self.put(key, value) self._put(key, value)
@overrides def _get(self, key: str) -> Optional[Any]:
def get(self, key: str) -> Optional[Any]:
if key in self._data: if key in self._data:
return self._data[key] return self._data[key]
else: else:
return None return None
@overrides def _put(self, key: str, value: Any) -> None:
def put(self, key: str, value: Any) -> None:
self._data[key] = value self._data[key] = value
@ -200,20 +183,18 @@ class Hdf5Sample(Sample):
@overrides @overrides
def put_scalar(self, key: str, value: Any) -> None: def put_scalar(self, key: str, value: Any) -> None:
self._assert_is_scalar(value) self._assert_is_scalar(value)
self.put(key, value) self._put(key, value)
@overrides @overrides
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
if value is None: if value is None:
return return
self._assert_is_vector(value) self._assert_is_vector(value)
self.put(key, value) self._put(key, value)
@overrides @overrides
def put_vector_list(self, key: str, value: VectorList) -> None: def put_vector_list(self, key: str, value: VectorList) -> None:
self._assert_is_vector_list(value) self._assert_is_vector_list(value)
if key in self.file:
del self.file[key]
padded, lens = _pad(value) padded, lens = _pad(value)
data = None data = None
for v in value: for v in value:
@ -227,22 +208,13 @@ class Hdf5Sample(Sample):
data = np.array(padded) data = np.array(padded)
break break
assert data is not None assert data is not None
ds = self.file.create_dataset(key, data=data) ds = self._put(key, data)
ds.attrs["lengths"] = lens ds.attrs["lengths"] = lens
@overrides def _put(self, key: str, value: Any) -> Dataset:
def get(self, key: str) -> Optional[Any]:
ds = self.file[key]
if h5py.check_string_dtype(ds.dtype):
return ds.asstr()[:].tolist()
else:
return ds[:].tolist()
@overrides
def put(self, key: str, value: Any) -> None:
if key in self.file: if key in self.file:
del self.file[key] del self.file[key]
self.file.create_dataset(key, data=value) return self.file.create_dataset(key, data=value)
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]: def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:

@ -89,15 +89,14 @@ class TravelingSalesmanInstance(Instance):
self, self,
solver: InternalSolver, solver: InternalSolver,
model: Any, model: Any,
) -> List[FrozenSet]: ) -> List[str]:
selected_edges = [e for e in self.edges if model.x[e].value > 0.5] selected_edges = [e for e in self.edges if model.x[e].value > 0.5]
graph = nx.Graph() graph = nx.Graph()
graph.add_edges_from(selected_edges) graph.add_edges_from(selected_edges)
components = [frozenset(c) for c in list(nx.connected_components(graph))]
violations = [] violations = []
for c in components: for c in list(nx.connected_components(graph)):
if len(c) < self.n_cities: if len(c) < self.n_cities:
violations += [c] violations.append(",".join(map(str, c)))
return violations return violations
@overrides @overrides
@ -105,9 +104,10 @@ class TravelingSalesmanInstance(Instance):
self, self,
solver: InternalSolver, solver: InternalSolver,
model: Any, model: Any,
component: FrozenSet, violation: str,
) -> None: ) -> None:
assert isinstance(solver, BasePyomoSolver) assert isinstance(solver, BasePyomoSolver)
component = [int(v) for v in violation.split(",")]
cut_edges = [ cut_edges = [
e e
for e in self.edges for e in self.edges

@ -80,16 +80,16 @@ class Constraints:
@staticmethod @staticmethod
def from_sample(sample: "Sample") -> "Constraints": def from_sample(sample: "Sample") -> "Constraints":
return Constraints( return Constraints(
basis_status=sample.get("lp_constr_basis_status"), basis_status=sample.get_vector("lp_constr_basis_status"),
dual_values=sample.get("lp_constr_dual_values"), dual_values=sample.get_vector("lp_constr_dual_values"),
lazy=sample.get("constr_lazy"), lazy=sample.get_vector("constr_lazy"),
lhs=sample.get("constr_lhs"), # lhs=sample.get_vector("constr_lhs"),
names=sample.get("constr_names"), names=sample.get_vector("constr_names"),
rhs=sample.get("constr_rhs"), rhs=sample.get_vector("constr_rhs"),
sa_rhs_down=sample.get("lp_constr_sa_rhs_down"), sa_rhs_down=sample.get_vector("lp_constr_sa_rhs_down"),
sa_rhs_up=sample.get("lp_constr_sa_rhs_up"), sa_rhs_up=sample.get_vector("lp_constr_sa_rhs_up"),
senses=sample.get("constr_senses"), senses=sample.get_vector("constr_senses"),
slacks=sample.get("lp_constr_slacks"), slacks=sample.get_vector("lp_constr_slacks"),
) )
def __getitem__(self, selected: List[bool]) -> "Constraints": def __getitem__(self, selected: List[bool]) -> "Constraints":

@ -81,7 +81,7 @@ def test_usage(
) -> None: ) -> None:
stats_before = solver.solve(stab_instance) stats_before = solver.solve(stab_instance)
sample = stab_instance.get_samples()[0] sample = stab_instance.get_samples()[0]
user_cuts_enforced = sample.get("user_cuts_enforced") user_cuts_enforced = sample.get_set("user_cuts_enforced")
assert user_cuts_enforced is not None assert user_cuts_enforced is not None
assert len(user_cuts_enforced) > 0 assert len(user_cuts_enforced) > 0
assert stats_before["UserCuts: Added ahead-of-time"] == 0 assert stats_before["UserCuts: Added ahead-of-time"] == 0

@ -93,7 +93,7 @@ def test_usage_with_solver(instance: Instance) -> None:
stats: LearningSolveStats = {} stats: LearningSolveStats = {}
sample = instance.get_samples()[0] sample = instance.get_samples()[0]
assert sample.get("lazy_enforced") is not None assert sample.get_set("lazy_enforced") is not None
# LearningSolver calls before_solve_mip # LearningSolver calls before_solve_mip
component.before_solve_mip( component.before_solve_mip(
@ -142,7 +142,7 @@ def test_usage_with_solver(instance: Instance) -> None:
) )
# Should update training sample # Should update training sample
assert sample.get("lazy_enforced") == {"c1", "c2", "c3", "c4"} assert sample.get_set("lazy_enforced") == {"c1", "c2", "c3", "c4"}
# #
# Should update stats # Should update stats
assert stats["LazyStatic: Removed"] == 1 assert stats["LazyStatic: Removed"] == 1

@ -24,21 +24,23 @@ def test_knapsack() -> None:
# after-load # after-load
# ------------------------------------------------------- # -------------------------------------------------------
extractor.extract_after_load_features(instance, solver, sample) extractor.extract_after_load_features(instance, solver, sample)
assert_equals(sample.get("var_names"), ["x[0]", "x[1]", "x[2]", "x[3]", "z"]) assert_equals(sample.get_vector("var_names"), ["x[0]", "x[1]", "x[2]", "x[3]", "z"])
assert_equals(sample.get("var_lower_bounds"), [0.0, 0.0, 0.0, 0.0, 0.0]) assert_equals(sample.get_vector("var_lower_bounds"), [0.0, 0.0, 0.0, 0.0, 0.0])
assert_equals(sample.get("var_obj_coeffs"), [505.0, 352.0, 458.0, 220.0, 0.0])
assert_equals(sample.get("var_types"), ["B", "B", "B", "B", "C"])
assert_equals(sample.get("var_upper_bounds"), [1.0, 1.0, 1.0, 1.0, 67.0])
assert_equals( assert_equals(
sample.get("var_categories"), sample.get_vector("var_obj_coeffs"), [505.0, 352.0, 458.0, 220.0, 0.0]
)
assert_equals(sample.get_vector("var_types"), ["B", "B", "B", "B", "C"])
assert_equals(sample.get_vector("var_upper_bounds"), [1.0, 1.0, 1.0, 1.0, 67.0])
assert_equals(
sample.get_vector("var_categories"),
["default", "default", "default", "default", None], ["default", "default", "default", "default", None],
) )
assert_equals( assert_equals(
sample.get("var_features_user"), sample.get_vector_list("var_features_user"),
[[23.0, 505.0], [26.0, 352.0], [20.0, 458.0], [18.0, 220.0], None], [[23.0, 505.0], [26.0, 352.0], [20.0, 458.0], [18.0, 220.0], None],
) )
assert_equals( assert_equals(
sample.get("var_features_AlvLouWeh2017"), sample.get_vector_list("var_features_AlvLouWeh2017"),
[ [
[1.0, 0.32899, 0.0], [1.0, 0.32899, 0.0],
[1.0, 0.229316, 0.0], [1.0, 0.229316, 0.0],
@ -47,61 +49,63 @@ def test_knapsack() -> None:
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
], ],
) )
assert sample.get("var_features") is not None assert sample.get_vector_list("var_features") is not None
assert_equals(sample.get("constr_names"), ["eq_capacity"]) assert_equals(sample.get_vector("constr_names"), ["eq_capacity"])
assert_equals( # assert_equals(
sample.get("constr_lhs"), # sample.get_vector("constr_lhs"),
[ # [
[ # [
("x[0]", 23.0), # ("x[0]", 23.0),
("x[1]", 26.0), # ("x[1]", 26.0),
("x[2]", 20.0), # ("x[2]", 20.0),
("x[3]", 18.0), # ("x[3]", 18.0),
("z", -1.0), # ("z", -1.0),
], # ],
], # ],
) # )
assert_equals(sample.get("constr_rhs"), [0.0]) assert_equals(sample.get_vector("constr_rhs"), [0.0])
assert_equals(sample.get("constr_senses"), ["="]) assert_equals(sample.get_vector("constr_senses"), ["="])
assert_equals(sample.get("constr_features_user"), [None]) assert_equals(sample.get_vector("constr_features_user"), [None])
assert_equals(sample.get("constr_categories"), ["eq_capacity"]) assert_equals(sample.get_vector("constr_categories"), ["eq_capacity"])
assert_equals(sample.get("constr_lazy"), [False]) assert_equals(sample.get_vector("constr_lazy"), [False])
assert_equals(sample.get("instance_features_user"), [67.0, 21.75]) assert_equals(sample.get_vector("instance_features_user"), [67.0, 21.75])
assert_equals(sample.get("static_lazy_count"), 0) assert_equals(sample.get_scalar("static_lazy_count"), 0)
# after-lp # after-lp
# ------------------------------------------------------- # -------------------------------------------------------
solver.solve_lp() solver.solve_lp()
extractor.extract_after_lp_features(solver, sample) extractor.extract_after_lp_features(solver, sample)
assert_equals( assert_equals(
sample.get("lp_var_basis_status"), sample.get_vector("lp_var_basis_status"),
["U", "B", "U", "L", "U"], ["U", "B", "U", "L", "U"],
) )
assert_equals( assert_equals(
sample.get("lp_var_reduced_costs"), sample.get_vector("lp_var_reduced_costs"),
[193.615385, 0.0, 187.230769, -23.692308, 13.538462], [193.615385, 0.0, 187.230769, -23.692308, 13.538462],
) )
assert_equals( assert_equals(
sample.get("lp_var_sa_lb_down"), sample.get_vector("lp_var_sa_lb_down"),
[-inf, -inf, -inf, -0.111111, -inf], [-inf, -inf, -inf, -0.111111, -inf],
) )
assert_equals( assert_equals(
sample.get("lp_var_sa_lb_up"), sample.get_vector("lp_var_sa_lb_up"),
[1.0, 0.923077, 1.0, 1.0, 67.0], [1.0, 0.923077, 1.0, 1.0, 67.0],
) )
assert_equals( assert_equals(
sample.get("lp_var_sa_obj_down"), sample.get_vector("lp_var_sa_obj_down"),
[311.384615, 317.777778, 270.769231, -inf, -13.538462], [311.384615, 317.777778, 270.769231, -inf, -13.538462],
) )
assert_equals( assert_equals(
sample.get("lp_var_sa_obj_up"), sample.get_vector("lp_var_sa_obj_up"),
[inf, 570.869565, inf, 243.692308, inf], [inf, 570.869565, inf, 243.692308, inf],
) )
assert_equals(sample.get("lp_var_sa_ub_down"), [0.913043, 0.923077, 0.9, 0.0, 43.0])
assert_equals(sample.get("lp_var_sa_ub_up"), [2.043478, inf, 2.2, inf, 69.0])
assert_equals(sample.get("lp_var_values"), [1.0, 0.923077, 1.0, 0.0, 67.0])
assert_equals( assert_equals(
sample.get("lp_var_features_AlvLouWeh2017"), sample.get_vector("lp_var_sa_ub_down"), [0.913043, 0.923077, 0.9, 0.0, 43.0]
)
assert_equals(sample.get_vector("lp_var_sa_ub_up"), [2.043478, inf, 2.2, inf, 69.0])
assert_equals(sample.get_vector("lp_var_values"), [1.0, 0.923077, 1.0, 0.0, 67.0])
assert_equals(
sample.get_vector_list("lp_var_features_AlvLouWeh2017"),
[ [
[1.0, 0.32899, 0.0, 0.0, 1.0, 1.0, 5.265874, 46.051702], [1.0, 0.32899, 0.0, 0.0, 1.0, 1.0, 5.265874, 46.051702],
[1.0, 0.229316, 0.0, 0.076923, 1.0, 1.0, 3.532875, 5.388476], [1.0, 0.229316, 0.0, 0.076923, 1.0, 1.0, 3.532875, 5.388476],
@ -110,19 +114,19 @@ def test_knapsack() -> None:
[0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0],
], ],
) )
assert sample.get("lp_var_features") is not None assert sample.get_vector_list("lp_var_features") is not None
assert_equals(sample.get("lp_constr_basis_status"), ["N"]) assert_equals(sample.get_vector("lp_constr_basis_status"), ["N"])
assert_equals(sample.get("lp_constr_dual_values"), [13.538462]) assert_equals(sample.get_vector("lp_constr_dual_values"), [13.538462])
assert_equals(sample.get("lp_constr_sa_rhs_down"), [-24.0]) assert_equals(sample.get_vector("lp_constr_sa_rhs_down"), [-24.0])
assert_equals(sample.get("lp_constr_sa_rhs_up"), [2.0]) assert_equals(sample.get_vector("lp_constr_sa_rhs_up"), [2.0])
assert_equals(sample.get("lp_constr_slacks"), [0.0]) assert_equals(sample.get_vector("lp_constr_slacks"), [0.0])
# after-mip # after-mip
# ------------------------------------------------------- # -------------------------------------------------------
solver.solve() solver.solve()
extractor.extract_after_mip_features(solver, sample) extractor.extract_after_mip_features(solver, sample)
assert_equals(sample.get("mip_var_values"), [1.0, 0.0, 1.0, 1.0, 61.0]) assert_equals(sample.get_vector("mip_var_values"), [1.0, 0.0, 1.0, 1.0, 61.0])
assert_equals(sample.get("mip_constr_slacks"), [0.0]) assert_equals(sample.get_vector("mip_constr_slacks"), [0.0])
def test_constraint_getindex() -> None: def test_constraint_getindex() -> None:

@ -41,9 +41,9 @@ def test_instance() -> None:
solver.solve(instance) solver.solve(instance)
assert len(instance.get_samples()) == 1 assert len(instance.get_samples()) == 1
sample = instance.get_samples()[0] sample = instance.get_samples()[0]
assert sample.get("mip_var_values") == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0] assert sample.get_vector("mip_var_values") == [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
assert sample.get("mip_lower_bound") == 4.0 assert sample.get_scalar("mip_lower_bound") == 4.0
assert sample.get("mip_upper_bound") == 4.0 assert sample.get_scalar("mip_upper_bound") == 4.0
def test_subtour() -> None: def test_subtour() -> None:
@ -65,10 +65,10 @@ def test_subtour() -> None:
samples = instance.get_samples() samples = instance.get_samples()
assert len(samples) == 1 assert len(samples) == 1
sample = samples[0] sample = samples[0]
lazy_enforced = sample.get("lazy_enforced") lazy_enforced = sample.get_set("lazy_enforced")
assert lazy_enforced is not None assert lazy_enforced is not None
assert len(lazy_enforced) > 0 assert len(lazy_enforced) > 0
assert sample.get("mip_var_values") == [ assert sample.get_vector("mip_var_values") == [
1.0, 1.0,
0.0, 0.0,
0.0, 0.0,

@ -38,16 +38,18 @@ def test_learning_solver(
assert len(instance.get_samples()) > 0 assert len(instance.get_samples()) > 0
sample = instance.get_samples()[0] sample = instance.get_samples()[0]
assert sample.get("mip_var_values") == [1.0, 0.0, 1.0, 1.0, 61.0] assert sample.get_vector("mip_var_values") == [1.0, 0.0, 1.0, 1.0, 61.0]
assert sample.get("mip_lower_bound") == 1183.0 assert sample.get_scalar("mip_lower_bound") == 1183.0
assert sample.get("mip_upper_bound") == 1183.0 assert sample.get_scalar("mip_upper_bound") == 1183.0
mip_log = sample.get("mip_log") mip_log = sample.get_scalar("mip_log")
assert mip_log is not None assert mip_log is not None
assert len(mip_log) > 100 assert len(mip_log) > 100
assert_equals(sample.get("lp_var_values"), [1.0, 0.923077, 1.0, 0.0, 67.0]) assert_equals(
assert_equals(sample.get("lp_value"), 1287.923077) sample.get_vector("lp_var_values"), [1.0, 0.923077, 1.0, 0.0, 67.0]
lp_log = sample.get("lp_log") )
assert_equals(sample.get_scalar("lp_value"), 1287.923077)
lp_log = sample.get_scalar("lp_log")
assert lp_log is not None assert lp_log is not None
assert len(lp_log) > 100 assert len(lp_log) > 100

Loading…
Cancel
Save