Split Sample.{get,put} into {get,put}_{scalar,vector}

This commit is contained in:
2021-07-14 10:50:54 -05:00
parent 0a399deeee
commit 8fc7c6ab71
4 changed files with 165 additions and 68 deletions

View File

@@ -33,15 +33,15 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=True)
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
sample.put("var_lower_bounds", variables.lower_bounds)
sample.put("var_names", variables.names)
sample.put("var_obj_coeffs", variables.obj_coeffs)
sample.put("var_types", variables.types)
sample.put("var_upper_bounds", variables.upper_bounds)
sample.put("constr_names", constraints.names)
sample.put_vector("var_lower_bounds", variables.lower_bounds)
sample.put_vector("var_names", variables.names)
sample.put_vector("var_obj_coeffs", variables.obj_coeffs)
sample.put_vector("var_types", variables.types)
sample.put_vector("var_upper_bounds", variables.upper_bounds)
sample.put_vector("constr_names", constraints.names)
sample.put("constr_lhs", constraints.lhs)
sample.put("constr_rhs", constraints.rhs)
sample.put("constr_senses", constraints.senses)
sample.put_vector("constr_rhs", constraints.rhs)
sample.put_vector("constr_senses", constraints.senses)
self._extract_user_features_vars(instance, sample)
self._extract_user_features_constrs(instance, sample)
self._extract_user_features_instance(instance, sample)
@@ -67,20 +67,20 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=False, with_sa=self.with_sa)
constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa)
sample.put("lp_var_basis_status", variables.basis_status)
sample.put("lp_var_reduced_costs", variables.reduced_costs)
sample.put("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put("lp_var_values", variables.values)
sample.put("lp_constr_basis_status", constraints.basis_status)
sample.put("lp_constr_dual_values", constraints.dual_values)
sample.put("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put("lp_constr_slacks", constraints.slacks)
sample.put_vector("lp_var_basis_status", variables.basis_status)
sample.put_vector("lp_var_reduced_costs", variables.reduced_costs)
sample.put_vector("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put_vector("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put_vector("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put_vector("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put_vector("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put_vector("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put_vector("lp_var_values", variables.values)
sample.put_vector("lp_constr_basis_status", constraints.basis_status)
sample.put_vector("lp_constr_dual_values", constraints.dual_values)
sample.put_vector("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks)
self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_")
sample.put(
"lp_var_features",
@@ -134,8 +134,8 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=False, with_sa=False)
constraints = solver.get_constraints(with_static=False, with_sa=False)
sample.put("mip_var_values", variables.values)
sample.put("mip_constr_slacks", constraints.slacks)
sample.put_vector("mip_var_values", variables.values)
sample.put_vector("mip_constr_slacks", constraints.slacks)
def _extract_user_features_vars(
self,
@@ -228,7 +228,7 @@ class FeaturesExtractor:
else:
lazy.append(False)
sample.put("constr_features_user", user_features)
sample.put("constr_lazy", lazy)
sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories)
def _extract_user_features_instance(
@@ -251,7 +251,7 @@ class FeaturesExtractor:
constr_lazy = sample.get("constr_lazy")
assert constr_lazy is not None
sample.put("instance_features_user", user_features)
sample.put("static_lazy_count", sum(constr_lazy))
sample.put_scalar("static_lazy_count", sum(constr_lazy))
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
# approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195.

View File

@@ -3,16 +3,34 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
from typing import Dict, Optional, Any, Union, List
import h5py
import numpy as np
from overrides import overrides
Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]]
class Sample(ABC):
"""Abstract dictionary-like class that stores training data."""
@abstractmethod
def get_scalar(self, key: str) -> Optional[Any]:
    """Retrieve the value stored under `key`, which is expected to be a scalar.

    Abstract: concrete samples supply the actual lookup.
    """
@abstractmethod
def put_scalar(self, key: str, value: Scalar) -> None:
    """Store the scalar `value` under `key`.

    Abstract: concrete samples supply the actual storage.
    """
@abstractmethod
def get_vector(self, key: str) -> Optional[Any]:
    """Retrieve the value stored under `key`, which is expected to be a vector.

    Abstract: concrete samples supply the actual lookup.
    """
@abstractmethod
def put_vector(self, key: str, value: Vector) -> None:
    """Store the vector `value` under `key`.

    Abstract: concrete samples supply the actual storage.
    """
@abstractmethod
def get(self, key: str) -> Optional[Any]:
    """Retrieve whatever value is stored under `key`, with no shape expectation.

    Abstract: concrete samples supply the actual lookup.
    """
@@ -33,6 +51,8 @@ class Sample(ABC):
def _is_primitive(v: Any) -> bool:
if isinstance(v, (str, bool, int, float)):
return True
if v is None:
return True
return False
if _is_primitive(value):
@@ -40,8 +60,23 @@ class Sample(ABC):
if isinstance(value, list):
if _is_primitive(value[0]):
return
if isinstance(value[0], list):
if _is_primitive(value[0][0]):
return
assert False, f"Value has unsupported type: {value}"
def _assert_scalar(self, value: Any) -> None:
if value is None:
return
if isinstance(value, (str, bool, int, float)):
return
assert False, f"Scalar expected; found instead: {value}"
def _assert_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
self._assert_scalar(v)
class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory."""
@@ -54,6 +89,26 @@ class MemorySample(Sample):
data = {}
self._data: Dict[str, Any] = data
@overrides
def get_scalar(self, key: str) -> Optional[Any]:
    """Look up `key` through the generic `get`; no scalar check is done on read."""
    stored = self.get(key)
    return stored
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
    """Store `value` under `key` after checking that it is a scalar.

    The check is delegated to `_assert_scalar` and the storage to `put`.
    """
    # Validate first, so that a rejected value never reaches the store.
    self._assert_scalar(value)
    self.put(key, value)
@overrides
def get_vector(self, key: str) -> Optional[Any]:
    """Look up `key` through the generic `get`; no vector check is done on read."""
    stored = self.get(key)
    return stored
@overrides
def put_vector(self, key: str, value: Vector) -> None:
    """Store the list `value` under `key` after validating it.

    A `value` of None is ignored: nothing is validated and nothing is
    stored. Otherwise the list is checked with `_assert_vector` and handed
    to `put`.
    """
    if value is not None:
        self._assert_vector(value)
        self.put(key, value)
@overrides
def get(self, key: str) -> Optional[Any]:
if key in self._data:
@@ -63,7 +118,6 @@ class MemorySample(Sample):
@overrides
def put(self, key: str, value: Any) -> None:
    """Store `value` under `key` in the in-memory dictionary, without validation."""
    self._data[key] = value
@@ -78,23 +132,46 @@ class Hdf5Sample(Sample):
def __init__(self, filename: str) -> None:
    """Open the HDF5 file `filename` in "r+" mode and keep the handle in `self.file`."""
    handle = h5py.File(filename, "r+")
    self.file = handle
@overrides
def get_scalar(self, key: str) -> Optional[Any]:
    """Read the zero-dimensional dataset `key` and return it as a Python value.

    String-typed datasets are decoded through `asstr()`; any other dataset
    is converted with `tolist()`. The dataset must have an empty shape.
    A missing `key` is not handled here, so whatever the file lookup raises
    propagates to the caller.
    """
    dataset = self.file[key]
    assert len(dataset.shape) == 0
    if h5py.check_string_dtype(dataset.dtype):
        return dataset.asstr()[()]
    return dataset[()].tolist()
@overrides
def get_vector(self, key: str) -> Optional[Any]:
    """Read the one-dimensional dataset `key` and return it as a Python list.

    String-typed datasets are decoded through `asstr()` before conversion;
    any other dataset is converted directly with `tolist()`. The dataset
    must have exactly one dimension. A missing `key` is not handled here,
    so whatever the file lookup raises propagates to the caller.
    """
    dataset = self.file[key]
    assert len(dataset.shape) == 1
    if h5py.check_string_dtype(dataset.dtype):
        return dataset.asstr()[:].tolist()
    return dataset[:].tolist()
@overrides
def put_scalar(self, key: str, value: Any) -> None:
    """Store the scalar `value` under `key` in the HDF5 file.

    A `value` of None is ignored, mirroring `put_vector`: nothing is
    validated and the file is left untouched. Any other value is checked
    with `_assert_scalar` and then written through `put`.

    Args:
        key: Name of the dataset to write.
        value: None, or a str, bool, int or float.

    Raises:
        AssertionError: If `value` is neither None nor a scalar.
    """
    if value is None:
        # None passes the scalar check, but `put` would first delete any
        # existing dataset and then call create_dataset(data=None), which
        # h5py rejects. Skip the write instead.
        return
    self._assert_scalar(value)
    self.put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
    """Store the list `value` under `key` in the HDF5 file after validating it.

    A `value` of None is ignored: nothing is validated and nothing is
    written. Otherwise the list is checked with `_assert_vector` and handed
    to `put`.
    """
    if value is not None:
        self._assert_vector(value)
        self.put(key, value)
@overrides
def get(self, key: str) -> Optional[Any]:
ds = self.file[key]
if h5py.check_string_dtype(ds.dtype):
if ds.shape == ():
return ds.asstr()[()]
else:
return ds.asstr()[:].tolist()
return ds.asstr()[:].tolist()
else:
if ds.shape == ():
return ds[()].tolist()
else:
return ds[:].tolist()
return ds[:].tolist()
@overrides
def put(self, key: str, value: Any) -> None:
    """Write `value` to the dataset `key`, replacing any dataset already there.

    The value is first checked with `_assert_supported`. An existing
    dataset with the same name is deleted before the new one is created.
    """
    self._assert_supported(value)
    store = self.file
    if key in store:
        # Remove the previous dataset so it can be recreated under this name.
        del store[key]
    store.create_dataset(key, data=value)