Implement {get,put}_array; make other methods deprecated

This commit is contained in:
2021-08-08 06:52:24 -05:00
parent 0a32586bf8
commit f69067aafd
4 changed files with 98 additions and 167 deletions

View File

@@ -33,14 +33,14 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=True)
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
sample.put_vector("static_var_lower_bounds", variables.lower_bounds)
sample.put_array("static_var_lower_bounds", variables.lower_bounds)
sample.put_vector("static_var_names", variables.names)
sample.put_vector("static_var_obj_coeffs", variables.obj_coeffs)
sample.put_array("static_var_obj_coeffs", variables.obj_coeffs)
sample.put_vector("static_var_types", variables.types)
sample.put_vector("static_var_upper_bounds", variables.upper_bounds)
sample.put_array("static_var_upper_bounds", variables.upper_bounds)
sample.put_vector("static_constr_names", constraints.names)
# sample.put("static_constr_lhs", constraints.lhs)
sample.put_vector("static_constr_rhs", constraints.rhs)
sample.put_array("static_constr_rhs", constraints.rhs)
sample.put_vector("static_constr_senses", constraints.senses)
vars_features_user, var_categories = self._extract_user_features_vars(
instance, sample
@@ -55,9 +55,9 @@ class FeaturesExtractor:
[
alw17,
vars_features_user,
sample.get_vector("static_var_lower_bounds"),
sample.get_vector("static_var_obj_coeffs"),
sample.get_vector("static_var_upper_bounds"),
sample.get_array("static_var_lower_bounds"),
sample.get_array("static_var_obj_coeffs"),
sample.get_array("static_var_upper_bounds"),
],
),
)
@@ -70,33 +70,33 @@ class FeaturesExtractor:
variables = solver.get_variables(with_static=False, with_sa=self.with_sa)
constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa)
sample.put_vector("lp_var_basis_status", variables.basis_status)
sample.put_vector("lp_var_reduced_costs", variables.reduced_costs)
sample.put_vector("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put_vector("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put_vector("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put_vector("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put_vector("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put_vector("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put_vector("lp_var_values", variables.values)
sample.put_array("lp_var_reduced_costs", variables.reduced_costs)
sample.put_array("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put_array("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put_array("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put_array("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put_array("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put_array("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put_array("lp_var_values", variables.values)
sample.put_vector("lp_constr_basis_status", constraints.basis_status)
sample.put_vector("lp_constr_dual_values", constraints.dual_values)
sample.put_vector("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks)
sample.put_array("lp_constr_dual_values", constraints.dual_values)
sample.put_array("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put_array("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_array("lp_constr_slacks", constraints.slacks)
alw17 = self._extract_var_features_AlvLouWeh2017(sample)
sample.put_vector_list(
"lp_var_features",
self._combine(
[
alw17,
sample.get_vector("lp_var_reduced_costs"),
sample.get_vector("lp_var_sa_lb_down"),
sample.get_vector("lp_var_sa_lb_up"),
sample.get_vector("lp_var_sa_obj_down"),
sample.get_vector("lp_var_sa_obj_up"),
sample.get_vector("lp_var_sa_ub_down"),
sample.get_vector("lp_var_sa_ub_up"),
sample.get_vector("lp_var_values"),
sample.get_array("lp_var_reduced_costs"),
sample.get_array("lp_var_sa_lb_down"),
sample.get_array("lp_var_sa_lb_up"),
sample.get_array("lp_var_sa_obj_down"),
sample.get_array("lp_var_sa_obj_up"),
sample.get_array("lp_var_sa_ub_down"),
sample.get_array("lp_var_sa_ub_up"),
sample.get_array("lp_var_values"),
sample.get_vector_list("static_var_features"),
],
),
@@ -106,10 +106,10 @@ class FeaturesExtractor:
self._combine(
[
sample.get_vector_list("static_constr_features"),
sample.get_vector("lp_constr_dual_values"),
sample.get_vector("lp_constr_sa_rhs_down"),
sample.get_vector("lp_constr_sa_rhs_up"),
sample.get_vector("lp_constr_slacks"),
sample.get_array("lp_constr_dual_values"),
sample.get_array("lp_constr_sa_rhs_down"),
sample.get_array("lp_constr_sa_rhs_up"),
sample.get_array("lp_constr_slacks"),
],
),
)
@@ -131,8 +131,8 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=False, with_sa=False)
constraints = solver.get_constraints(with_static=False, with_sa=False)
sample.put_vector("mip_var_values", variables.values)
sample.put_vector("mip_constr_slacks", constraints.slacks)
sample.put_array("mip_var_values", variables.values)
sample.put_array("mip_constr_slacks", constraints.slacks)
def _extract_user_features_vars(
self,

View File

@@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
@@ -39,11 +39,12 @@ class Sample(ABC):
@abstractmethod
def get_bytes(self, key: str) -> Optional[Bytes]:
pass
warnings.warn("Deprecated", DeprecationWarning)
return None
@abstractmethod
def put_bytes(self, key: str, value: Bytes) -> None:
pass
warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod
def get_scalar(self, key: str) -> Optional[Any]:
@@ -55,18 +56,28 @@ class Sample(ABC):
@abstractmethod
def get_vector(self, key: str) -> Optional[Any]:
pass
warnings.warn("Deprecated", DeprecationWarning)
return None
@abstractmethod
def put_vector(self, key: str, value: Vector) -> None:
pass
warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod
def get_vector_list(self, key: str) -> Optional[Any]:
pass
warnings.warn("Deprecated", DeprecationWarning)
return None
@abstractmethod
def put_vector_list(self, key: str, value: VectorList) -> None:
warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
pass
@abstractmethod
def get_array(self, key: str) -> Optional[np.ndarray]:
pass
def get_set(self, key: str) -> Set:
@@ -103,6 +114,10 @@ class Sample(ABC):
continue
self._assert_is_vector(v)
def _assert_supported(self, value: np.ndarray) -> None:
assert isinstance(value, np.ndarray)
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory."""
@@ -171,6 +186,17 @@ class MemorySample(Sample):
def _put(self, key: str, value: Any) -> None:
self._data[key] = value
@overrides
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
if value is None:
return
self._assert_supported(value)
self._put(key, value)
@overrides
def get_array(self, key: str) -> Optional[np.ndarray]:
return cast(Optional[np.ndarray], self._get(key))
class Hdf5Sample(Sample):
"""
@@ -310,6 +336,21 @@ class Hdf5Sample(Sample):
ds = self.file.create_dataset(key, data=value)
return ds
@overrides
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
if value is None:
return
self._assert_supported(value)
if key in self.file:
del self.file[key]
return self.file.create_dataset(key, data=value, compression="gzip")
@overrides
def get_array(self, key: str) -> Optional[np.ndarray]:
if key not in self.file:
return None
return self.file[key][:]
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist)

View File

@@ -332,7 +332,7 @@ class GurobiSolver(InternalSolver):
obj_coeffs = self._var_obj_coeffs
if self._has_lp_solution:
reduced_costs = model.getAttr("rc", self._gp_vars)
reduced_costs = np.array(model.getAttr("rc", self._gp_vars), dtype=float)
basis_status = list(
map(
_parse_gurobi_vbasis,