Implement {get,put}_array; make other methods deprecated

master
Alinson S. Xavier 4 years ago
parent 0a32586bf8
commit f69067aafd
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -33,14 +33,14 @@ class FeaturesExtractor:
) -> None: ) -> None:
variables = solver.get_variables(with_static=True) variables = solver.get_variables(with_static=True)
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs) constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
sample.put_vector("static_var_lower_bounds", variables.lower_bounds) sample.put_array("static_var_lower_bounds", variables.lower_bounds)
sample.put_vector("static_var_names", variables.names) sample.put_vector("static_var_names", variables.names)
sample.put_vector("static_var_obj_coeffs", variables.obj_coeffs) sample.put_array("static_var_obj_coeffs", variables.obj_coeffs)
sample.put_vector("static_var_types", variables.types) sample.put_vector("static_var_types", variables.types)
sample.put_vector("static_var_upper_bounds", variables.upper_bounds) sample.put_array("static_var_upper_bounds", variables.upper_bounds)
sample.put_vector("static_constr_names", constraints.names) sample.put_vector("static_constr_names", constraints.names)
# sample.put("static_constr_lhs", constraints.lhs) # sample.put("static_constr_lhs", constraints.lhs)
sample.put_vector("static_constr_rhs", constraints.rhs) sample.put_array("static_constr_rhs", constraints.rhs)
sample.put_vector("static_constr_senses", constraints.senses) sample.put_vector("static_constr_senses", constraints.senses)
vars_features_user, var_categories = self._extract_user_features_vars( vars_features_user, var_categories = self._extract_user_features_vars(
instance, sample instance, sample
@ -55,9 +55,9 @@ class FeaturesExtractor:
[ [
alw17, alw17,
vars_features_user, vars_features_user,
sample.get_vector("static_var_lower_bounds"), sample.get_array("static_var_lower_bounds"),
sample.get_vector("static_var_obj_coeffs"), sample.get_array("static_var_obj_coeffs"),
sample.get_vector("static_var_upper_bounds"), sample.get_array("static_var_upper_bounds"),
], ],
), ),
) )
@ -70,33 +70,33 @@ class FeaturesExtractor:
variables = solver.get_variables(with_static=False, with_sa=self.with_sa) variables = solver.get_variables(with_static=False, with_sa=self.with_sa)
constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa) constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa)
sample.put_vector("lp_var_basis_status", variables.basis_status) sample.put_vector("lp_var_basis_status", variables.basis_status)
sample.put_vector("lp_var_reduced_costs", variables.reduced_costs) sample.put_array("lp_var_reduced_costs", variables.reduced_costs)
sample.put_vector("lp_var_sa_lb_down", variables.sa_lb_down) sample.put_array("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put_vector("lp_var_sa_lb_up", variables.sa_lb_up) sample.put_array("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put_vector("lp_var_sa_obj_down", variables.sa_obj_down) sample.put_array("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put_vector("lp_var_sa_obj_up", variables.sa_obj_up) sample.put_array("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put_vector("lp_var_sa_ub_down", variables.sa_ub_down) sample.put_array("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put_vector("lp_var_sa_ub_up", variables.sa_ub_up) sample.put_array("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put_vector("lp_var_values", variables.values) sample.put_array("lp_var_values", variables.values)
sample.put_vector("lp_constr_basis_status", constraints.basis_status) sample.put_vector("lp_constr_basis_status", constraints.basis_status)
sample.put_vector("lp_constr_dual_values", constraints.dual_values) sample.put_array("lp_constr_dual_values", constraints.dual_values)
sample.put_vector("lp_constr_sa_rhs_down", constraints.sa_rhs_down) sample.put_array("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up) sample.put_array("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks) sample.put_array("lp_constr_slacks", constraints.slacks)
alw17 = self._extract_var_features_AlvLouWeh2017(sample) alw17 = self._extract_var_features_AlvLouWeh2017(sample)
sample.put_vector_list( sample.put_vector_list(
"lp_var_features", "lp_var_features",
self._combine( self._combine(
[ [
alw17, alw17,
sample.get_vector("lp_var_reduced_costs"), sample.get_array("lp_var_reduced_costs"),
sample.get_vector("lp_var_sa_lb_down"), sample.get_array("lp_var_sa_lb_down"),
sample.get_vector("lp_var_sa_lb_up"), sample.get_array("lp_var_sa_lb_up"),
sample.get_vector("lp_var_sa_obj_down"), sample.get_array("lp_var_sa_obj_down"),
sample.get_vector("lp_var_sa_obj_up"), sample.get_array("lp_var_sa_obj_up"),
sample.get_vector("lp_var_sa_ub_down"), sample.get_array("lp_var_sa_ub_down"),
sample.get_vector("lp_var_sa_ub_up"), sample.get_array("lp_var_sa_ub_up"),
sample.get_vector("lp_var_values"), sample.get_array("lp_var_values"),
sample.get_vector_list("static_var_features"), sample.get_vector_list("static_var_features"),
], ],
), ),
@ -106,10 +106,10 @@ class FeaturesExtractor:
self._combine( self._combine(
[ [
sample.get_vector_list("static_constr_features"), sample.get_vector_list("static_constr_features"),
sample.get_vector("lp_constr_dual_values"), sample.get_array("lp_constr_dual_values"),
sample.get_vector("lp_constr_sa_rhs_down"), sample.get_array("lp_constr_sa_rhs_down"),
sample.get_vector("lp_constr_sa_rhs_up"), sample.get_array("lp_constr_sa_rhs_up"),
sample.get_vector("lp_constr_slacks"), sample.get_array("lp_constr_slacks"),
], ],
), ),
) )
@ -131,8 +131,8 @@ class FeaturesExtractor:
) -> None: ) -> None:
variables = solver.get_variables(with_static=False, with_sa=False) variables = solver.get_variables(with_static=False, with_sa=False)
constraints = solver.get_constraints(with_static=False, with_sa=False) constraints = solver.get_constraints(with_static=False, with_sa=False)
sample.put_vector("mip_var_values", variables.values) sample.put_array("mip_var_values", variables.values)
sample.put_vector("mip_constr_slacks", constraints.slacks) sample.put_array("mip_constr_slacks", constraints.slacks)
def _extract_user_features_vars( def _extract_user_features_vars(
self, self,

@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
@ -39,11 +39,12 @@ class Sample(ABC):
@abstractmethod @abstractmethod
def get_bytes(self, key: str) -> Optional[Bytes]: def get_bytes(self, key: str) -> Optional[Bytes]:
pass warnings.warn("Deprecated", DeprecationWarning)
return None
@abstractmethod @abstractmethod
def put_bytes(self, key: str, value: Bytes) -> None: def put_bytes(self, key: str, value: Bytes) -> None:
pass warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod @abstractmethod
def get_scalar(self, key: str) -> Optional[Any]: def get_scalar(self, key: str) -> Optional[Any]:
@ -55,18 +56,28 @@ class Sample(ABC):
@abstractmethod @abstractmethod
def get_vector(self, key: str) -> Optional[Any]: def get_vector(self, key: str) -> Optional[Any]:
pass warnings.warn("Deprecated", DeprecationWarning)
return None
@abstractmethod @abstractmethod
def put_vector(self, key: str, value: Vector) -> None: def put_vector(self, key: str, value: Vector) -> None:
pass warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod @abstractmethod
def get_vector_list(self, key: str) -> Optional[Any]: def get_vector_list(self, key: str) -> Optional[Any]:
pass warnings.warn("Deprecated", DeprecationWarning)
return None
@abstractmethod @abstractmethod
def put_vector_list(self, key: str, value: VectorList) -> None: def put_vector_list(self, key: str, value: VectorList) -> None:
warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
pass
@abstractmethod
def get_array(self, key: str) -> Optional[np.ndarray]:
pass pass
def get_set(self, key: str) -> Set: def get_set(self, key: str) -> Set:
@ -103,6 +114,10 @@ class Sample(ABC):
continue continue
self._assert_is_vector(v) self._assert_is_vector(v)
def _assert_supported(self, value: np.ndarray) -> None:
assert isinstance(value, np.ndarray)
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
class MemorySample(Sample): class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory.""" """Dictionary-like class that stores training data in-memory."""
@ -171,6 +186,17 @@ class MemorySample(Sample):
def _put(self, key: str, value: Any) -> None: def _put(self, key: str, value: Any) -> None:
self._data[key] = value self._data[key] = value
@overrides
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
if value is None:
return
self._assert_supported(value)
self._put(key, value)
@overrides
def get_array(self, key: str) -> Optional[np.ndarray]:
return cast(Optional[np.ndarray], self._get(key))
class Hdf5Sample(Sample): class Hdf5Sample(Sample):
""" """
@ -310,6 +336,21 @@ class Hdf5Sample(Sample):
ds = self.file.create_dataset(key, data=value) ds = self.file.create_dataset(key, data=value)
return ds return ds
@overrides
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
if value is None:
return
self._assert_supported(value)
if key in self.file:
del self.file[key]
return self.file.create_dataset(key, data=value, compression="gzip")
@overrides
def get_array(self, key: str) -> Optional[np.ndarray]:
if key not in self.file:
return None
return self.file[key][:]
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]: def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist) veclist = deepcopy(veclist)

@ -332,7 +332,7 @@ class GurobiSolver(InternalSolver):
obj_coeffs = self._var_obj_coeffs obj_coeffs = self._var_obj_coeffs
if self._has_lp_solution: if self._has_lp_solution:
reduced_costs = model.getAttr("rc", self._gp_vars) reduced_costs = np.array(model.getAttr("rc", self._gp_vars), dtype=float)
basis_status = list( basis_status = list(
map( map(
_parse_gurobi_vbasis, _parse_gurobi_vbasis,

@ -3,10 +3,10 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import Any from typing import Any
import numpy as np import numpy as np
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
from miplearn.solvers.tests import assert_equals
def test_memory_sample() -> None: def test_memory_sample() -> None:
@ -19,54 +19,29 @@ def test_hdf5_sample() -> None:
def _test_sample(sample: Sample) -> None: def _test_sample(sample: Sample) -> None:
# Scalar
_assert_roundtrip_scalar(sample, "A") _assert_roundtrip_scalar(sample, "A")
_assert_roundtrip_scalar(sample, True) _assert_roundtrip_scalar(sample, True)
_assert_roundtrip_scalar(sample, 1) _assert_roundtrip_scalar(sample, 1)
_assert_roundtrip_scalar(sample, 1.0) _assert_roundtrip_scalar(sample, 1.0)
_assert_roundtrip_array(sample, np.array([True, False], dtype="bool"))
# Vector _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int16"))
_assert_roundtrip_vector(sample, ["A", "BB", "CCC", None]) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int32"))
_assert_roundtrip_vector(sample, [True, True, False]) _assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int64"))
_assert_roundtrip_vector(sample, [1, 2, 3]) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float16"))
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0]) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float32"))
_assert_roundtrip_vector(sample, np.array([1.0, 2.0, 3.0]), check_type=False) _assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float64"))
_assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S"))
# VectorList
_assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
_assert_roundtrip_vector_list(sample, [[True], [False, False], None])
_assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]])
_assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]])
_assert_roundtrip_vector_list(sample, [None, None])
# Bytes
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
_assert_roundtrip_bytes(
sample,
bytearray(b"\x00\x01\x02\x03\x04\x05"),
check_type=False,
)
# Querying unknown keys should return None
assert sample.get_scalar("unknown-key") is None assert sample.get_scalar("unknown-key") is None
assert sample.get_vector("unknown-key") is None assert sample.get_array("unknown-key") is None
assert sample.get_vector_list("unknown-key") is None
assert sample.get_bytes("unknown-key") is None
# Putting None should not modify HDF5 file
sample.put_scalar("key", None)
sample.put_vector("key", None)
def _assert_roundtrip_bytes( def _assert_roundtrip_array(sample: Sample, expected: Any) -> None:
sample: Sample, expected: Any, check_type: bool = False sample.put_array("key", expected)
) -> None: actual = sample.get_array("key")
sample.put_bytes("key", expected)
actual = sample.get_bytes("key")
assert actual == expected
assert actual is not None assert actual is not None
if check_type: assert isinstance(actual, np.ndarray)
_assert_same_type(actual, expected) assert actual.dtype == expected.dtype
assert (actual == expected).all()
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
@ -74,91 +49,6 @@ def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
actual = sample.get_scalar("key") actual = sample.get_scalar("key")
assert actual == expected assert actual == expected
assert actual is not None assert actual is not None
_assert_same_type(actual, expected)
def _assert_roundtrip_vector(
sample: Sample, expected: Any, check_type: bool = True
) -> None:
sample.put_vector("key", expected)
actual = sample.get_vector("key")
assert_equals(actual, expected)
assert actual is not None
if check_type:
_assert_same_type(actual[0], expected[0])
def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None:
sample.put_vector_list("key", expected)
actual = sample.get_vector_list("key")
assert actual == expected
assert actual is not None
if actual[0] is not None:
_assert_same_type(actual[0][0], expected[0][0])
def _assert_same_type(actual: Any, expected: Any) -> None:
assert isinstance( assert isinstance(
actual, expected.__class__ actual, expected.__class__
), f"Expected {expected.__class__}, found {actual.__class__} instead" ), f"Expected {expected.__class__}, found {actual.__class__} instead"
def test_pad_int() -> None:
_assert_roundtrip_pad(
original=[[1], [2, 2, 2], [], [3, 3], [4, 4, 4, 4], None],
expected_padded=[
[1, 0, 0, 0],
[2, 2, 2, 0],
[0, 0, 0, 0],
[3, 3, 0, 0],
[4, 4, 4, 4],
[0, 0, 0, 0],
],
expected_lens=[1, 3, 0, 2, 4, -1],
dtype=int,
)
def test_pad_float() -> None:
_assert_roundtrip_pad(
original=[[1.0], [2.0, 2.0, 2.0], [3.0, 3.0], [4.0, 4.0, 4.0, 4.0], None],
expected_padded=[
[1.0, 0.0, 0.0, 0.0],
[2.0, 2.0, 2.0, 0.0],
[3.0, 3.0, 0.0, 0.0],
[4.0, 4.0, 4.0, 4.0],
[0.0, 0.0, 0.0, 0.0],
],
expected_lens=[1, 3, 2, 4, -1],
dtype=float,
)
def test_pad_str() -> None:
_assert_roundtrip_pad(
original=[["A"], ["B", "B", "B"], ["C", "C"]],
expected_padded=[["A", "", ""], ["B", "B", "B"], ["C", "C", ""]],
expected_lens=[1, 3, 2],
dtype=str,
)
def _assert_roundtrip_pad(
original: Any,
expected_padded: Any,
expected_lens: Any,
dtype: Any,
) -> None:
actual_padded, actual_lens = _pad(original)
assert actual_padded == expected_padded
assert actual_lens == expected_lens
for v in actual_padded:
for vi in v: # type: ignore
assert isinstance(vi, dtype)
cropped = _crop(actual_padded, actual_lens)
assert cropped == original
for v in cropped:
if v is None:
continue
for vi in v: # type: ignore
assert isinstance(vi, dtype)

Loading…
Cancel
Save