Implement {get,put}_array; make other methods deprecated

master
Alinson S. Xavier 4 years ago
parent 0a32586bf8
commit f69067aafd
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -33,14 +33,14 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=True)
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
sample.put_vector("static_var_lower_bounds", variables.lower_bounds)
sample.put_array("static_var_lower_bounds", variables.lower_bounds)
sample.put_vector("static_var_names", variables.names)
sample.put_vector("static_var_obj_coeffs", variables.obj_coeffs)
sample.put_array("static_var_obj_coeffs", variables.obj_coeffs)
sample.put_vector("static_var_types", variables.types)
sample.put_vector("static_var_upper_bounds", variables.upper_bounds)
sample.put_array("static_var_upper_bounds", variables.upper_bounds)
sample.put_vector("static_constr_names", constraints.names)
# sample.put("static_constr_lhs", constraints.lhs)
sample.put_vector("static_constr_rhs", constraints.rhs)
sample.put_array("static_constr_rhs", constraints.rhs)
sample.put_vector("static_constr_senses", constraints.senses)
vars_features_user, var_categories = self._extract_user_features_vars(
instance, sample
@ -55,9 +55,9 @@ class FeaturesExtractor:
[
alw17,
vars_features_user,
sample.get_vector("static_var_lower_bounds"),
sample.get_vector("static_var_obj_coeffs"),
sample.get_vector("static_var_upper_bounds"),
sample.get_array("static_var_lower_bounds"),
sample.get_array("static_var_obj_coeffs"),
sample.get_array("static_var_upper_bounds"),
],
),
)
@ -70,33 +70,33 @@ class FeaturesExtractor:
variables = solver.get_variables(with_static=False, with_sa=self.with_sa)
constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa)
sample.put_vector("lp_var_basis_status", variables.basis_status)
sample.put_vector("lp_var_reduced_costs", variables.reduced_costs)
sample.put_vector("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put_vector("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put_vector("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put_vector("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put_vector("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put_vector("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put_vector("lp_var_values", variables.values)
sample.put_array("lp_var_reduced_costs", variables.reduced_costs)
sample.put_array("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put_array("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put_array("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put_array("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put_array("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put_array("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put_array("lp_var_values", variables.values)
sample.put_vector("lp_constr_basis_status", constraints.basis_status)
sample.put_vector("lp_constr_dual_values", constraints.dual_values)
sample.put_vector("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks)
sample.put_array("lp_constr_dual_values", constraints.dual_values)
sample.put_array("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put_array("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_array("lp_constr_slacks", constraints.slacks)
alw17 = self._extract_var_features_AlvLouWeh2017(sample)
sample.put_vector_list(
"lp_var_features",
self._combine(
[
alw17,
sample.get_vector("lp_var_reduced_costs"),
sample.get_vector("lp_var_sa_lb_down"),
sample.get_vector("lp_var_sa_lb_up"),
sample.get_vector("lp_var_sa_obj_down"),
sample.get_vector("lp_var_sa_obj_up"),
sample.get_vector("lp_var_sa_ub_down"),
sample.get_vector("lp_var_sa_ub_up"),
sample.get_vector("lp_var_values"),
sample.get_array("lp_var_reduced_costs"),
sample.get_array("lp_var_sa_lb_down"),
sample.get_array("lp_var_sa_lb_up"),
sample.get_array("lp_var_sa_obj_down"),
sample.get_array("lp_var_sa_obj_up"),
sample.get_array("lp_var_sa_ub_down"),
sample.get_array("lp_var_sa_ub_up"),
sample.get_array("lp_var_values"),
sample.get_vector_list("static_var_features"),
],
),
@ -106,10 +106,10 @@ class FeaturesExtractor:
self._combine(
[
sample.get_vector_list("static_constr_features"),
sample.get_vector("lp_constr_dual_values"),
sample.get_vector("lp_constr_sa_rhs_down"),
sample.get_vector("lp_constr_sa_rhs_up"),
sample.get_vector("lp_constr_slacks"),
sample.get_array("lp_constr_dual_values"),
sample.get_array("lp_constr_sa_rhs_down"),
sample.get_array("lp_constr_sa_rhs_up"),
sample.get_array("lp_constr_slacks"),
],
),
)
@ -131,8 +131,8 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=False, with_sa=False)
constraints = solver.get_constraints(with_static=False, with_sa=False)
sample.put_vector("mip_var_values", variables.values)
sample.put_vector("mip_constr_slacks", constraints.slacks)
sample.put_array("mip_var_values", variables.values)
sample.put_array("mip_constr_slacks", constraints.slacks)
def _extract_user_features_vars(
self,

@ -1,7 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Optional, Any, Union, List, Tuple, cast, Set
@ -39,11 +39,12 @@ class Sample(ABC):
# Deprecated accessor: emits DeprecationWarning("Deprecated") and returns None.
# NOTE(review): this is a flattened diff view; the former body (`pass`) and the
# new body (warn + return None) both appear below -- confirm against the
# committed file before relying on the exact statement list.
@abstractmethod
def get_bytes(self, key: str) -> Optional[Bytes]:
pass
warnings.warn("Deprecated", DeprecationWarning)
return None
# Deprecated mutator: emits DeprecationWarning("Deprecated") and stores nothing itself.
# NOTE(review): flattened diff view; the former body (`pass`) is shown next to
# the new warning line -- confirm against the committed file.
@abstractmethod
def put_bytes(self, key: str, value: Bytes) -> None:
pass
warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod
def get_scalar(self, key: str) -> Optional[Any]:
@ -55,18 +56,28 @@ class Sample(ABC):
# Deprecated accessor: emits DeprecationWarning("Deprecated") and returns None.
# NOTE(review): flattened diff view; the former body (`pass`) is shown next to
# the new body -- confirm against the committed file.
@abstractmethod
def get_vector(self, key: str) -> Optional[Any]:
pass
warnings.warn("Deprecated", DeprecationWarning)
return None
# Deprecated mutator: emits DeprecationWarning("Deprecated") and stores nothing itself.
# NOTE(review): flattened diff view; the former body (`pass`) is shown next to
# the new warning line -- confirm against the committed file.
@abstractmethod
def put_vector(self, key: str, value: Vector) -> None:
pass
warnings.warn("Deprecated", DeprecationWarning)
# Deprecated accessor: emits DeprecationWarning("Deprecated") and returns None.
# NOTE(review): flattened diff view; the former body (`pass`) is shown next to
# the new body -- confirm against the committed file.
@abstractmethod
def get_vector_list(self, key: str) -> Optional[Any]:
pass
warnings.warn("Deprecated", DeprecationWarning)
return None
# Deprecated mutator: emits DeprecationWarning("Deprecated") and stores nothing itself.
@abstractmethod
def put_vector_list(self, key: str, value: VectorList) -> None:
warnings.warn("Deprecated", DeprecationWarning)
@abstractmethod
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
    """
    Store the numpy array `value` under `key`.

    The abstract body does nothing. The concrete implementations visible in
    this file return early, without storing anything, when `value` is None.
    """
@abstractmethod
def get_array(self, key: str) -> Optional[np.ndarray]:
    """
    Return the numpy array stored under `key`.

    The abstract body returns None. The accompanying tests expect concrete
    implementations to return None for keys that were never stored.
    """
    return None
def get_set(self, key: str) -> Set:
@ -103,6 +114,10 @@ class Sample(ABC):
continue
self._assert_is_vector(v)
def _assert_supported(self, value: np.ndarray) -> None:
assert isinstance(value, np.ndarray)
assert value.dtype.kind in "biufS", f"Unsupported dtype: {value.dtype}"
class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory."""
@ -171,6 +186,17 @@ class MemorySample(Sample):
# Store `value` under `key` in the `self._data` container, replacing any
# existing entry for that key (plain item assignment).
def _put(self, key: str, value: Any) -> None:
self._data[key] = value
@overrides
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
    """
    Store the numpy array `value` under `key`.

    A `value` of None is ignored: nothing is validated or stored. Any other
    value is first checked with `_assert_supported`.
    """
    if value is not None:
        self._assert_supported(value)
        self._put(key, value)
@overrides
def get_array(self, key: str) -> Optional[np.ndarray]:
    """Return whatever `self._get(key)` yields, typed as an optional numpy array."""
    stored = self._get(key)
    # `cast` only informs the type checker; the value is returned unchanged.
    return cast(Optional[np.ndarray], stored)
class Hdf5Sample(Sample):
"""
@ -310,6 +336,21 @@ class Hdf5Sample(Sample):
ds = self.file.create_dataset(key, data=value)
return ds
@overrides
def put_array(self, key: str, value: Optional[np.ndarray]) -> None:
    """
    Write the numpy array `value` to the HDF5 file as dataset `key`.

    A `value` of None is ignored and leaves the file untouched. Any other
    value is checked with `_assert_supported`, an existing dataset with the
    same key is deleted, and a new gzip-compressed dataset is created.
    """
    if value is None:
        return
    self._assert_supported(value)
    # Delete any existing dataset so that the key can be created afresh.
    if key in self.file:
        del self.file[key]
    # The method is declared `-> None`, so the dataset handle returned by
    # `create_dataset` is deliberately not returned to the caller.
    self.file.create_dataset(key, data=value, compression="gzip")
@overrides
def get_array(self, key: str) -> Optional[np.ndarray]:
    """
    Read dataset `key` from the HDF5 file.

    Returns None when the file has no entry named `key`; otherwise returns
    the result of slicing the dataset with `[:]`.
    """
    if key in self.file:
        return self.file[key][:]
    return None
def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]:
veclist = deepcopy(veclist)

@ -332,7 +332,7 @@ class GurobiSolver(InternalSolver):
obj_coeffs = self._var_obj_coeffs
if self._has_lp_solution:
reduced_costs = model.getAttr("rc", self._gp_vars)
reduced_costs = np.array(model.getAttr("rc", self._gp_vars), dtype=float)
basis_status = list(
map(
_parse_gurobi_vbasis,

@ -3,10 +3,10 @@
# Released under the modified BSD license. See COPYING.md for more details.
from tempfile import NamedTemporaryFile
from typing import Any
import numpy as np
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample, _pad, _crop
from miplearn.solvers.tests import assert_equals
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
def test_memory_sample() -> None:
@ -19,54 +19,29 @@ def test_hdf5_sample() -> None:
# Shared test body run against a `Sample` instance: writes values through the
# put_* methods, reads them back through the matching get_* methods, and checks
# that unknown keys yield None.
# NOTE(review): this is a flattened diff view; the scalar/vector/vector-list/
# bytes checks and the array checks below come from different sides of the
# diff -- confirm against the committed file which of them remain.
def _test_sample(sample: Sample) -> None:
# Scalar
_assert_roundtrip_scalar(sample, "A")
_assert_roundtrip_scalar(sample, True)
_assert_roundtrip_scalar(sample, 1)
_assert_roundtrip_scalar(sample, 1.0)
# Vector
_assert_roundtrip_vector(sample, ["A", "BB", "CCC", None])
_assert_roundtrip_vector(sample, [True, True, False])
_assert_roundtrip_vector(sample, [1, 2, 3])
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
_assert_roundtrip_vector(sample, np.array([1.0, 2.0, 3.0]), check_type=False)
# VectorList
_assert_roundtrip_vector_list(sample, [["A"], ["BB", "CCC"], None])
_assert_roundtrip_vector_list(sample, [[True], [False, False], None])
_assert_roundtrip_vector_list(sample, [[1], None, [2, 2], [3, 3, 3]])
_assert_roundtrip_vector_list(sample, [[1.0], None, [2.0, 2.0], [3.0, 3.0, 3.0]])
_assert_roundtrip_vector_list(sample, [None, None])
# Bytes
_assert_roundtrip_bytes(sample, b"\x00\x01\x02\x03\x04\x05")
_assert_roundtrip_bytes(
sample,
bytearray(b"\x00\x01\x02\x03\x04\x05"),
check_type=False,
)
# Querying unknown keys should return None
# Array roundtrips, one per dtype: bool, int16/32/64, float16/32/64, byte string
_assert_roundtrip_array(sample, np.array([True, False], dtype="bool"))
_assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int16"))
_assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int32"))
_assert_roundtrip_array(sample, np.array([1, 2, 3], dtype="int64"))
_assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float16"))
_assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float32"))
_assert_roundtrip_array(sample, np.array([1.0, 2.0, 3.0], dtype="float64"))
_assert_roundtrip_array(sample, np.array(["A", "BB", "CCC"], dtype="S"))
assert sample.get_scalar("unknown-key") is None
assert sample.get_vector("unknown-key") is None
assert sample.get_vector_list("unknown-key") is None
assert sample.get_bytes("unknown-key") is None
# Putting None should not modify HDF5 file
sample.put_scalar("key", None)
sample.put_vector("key", None)
assert sample.get_array("unknown-key") is None
# Writes `expected` with put_bytes under "key", reads it back with get_bytes
# and asserts that the two compare equal.
# NOTE(review): `check_type` is not used in the lines shown here; in this
# flattened diff view the rest of the helper may appear further down.
def _assert_roundtrip_bytes(
sample: Sample, expected: Any, check_type: bool = False
) -> None:
sample.put_bytes("key", expected)
actual = sample.get_bytes("key")
assert actual == expected
# Writes `expected` with put_array under "key", reads it back with get_array
# and asserts that the result is a numpy array with the same dtype and
# element-wise equal contents.
# NOTE(review): the `if check_type:` / `_assert_same_type(...)` lines below
# reference a name that is not a parameter of this function; they look like
# leftovers from the removed bytes helper in this flattened diff view --
# confirm against the committed file.
def _assert_roundtrip_array(sample: Sample, expected: Any) -> None:
sample.put_array("key", expected)
actual = sample.get_array("key")
assert actual is not None
if check_type:
_assert_same_type(actual, expected)
assert isinstance(actual, np.ndarray)
assert actual.dtype == expected.dtype
assert (actual == expected).all()
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
@ -74,91 +49,6 @@ def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
actual = sample.get_scalar("key")
assert actual == expected
assert actual is not None
_assert_same_type(actual, expected)
def _assert_roundtrip_vector(
    sample: Sample, expected: Any, check_type: bool = True
) -> None:
    """
    Store `expected` with put_vector, read it back and compare.

    The value read back must equal `expected` (via `assert_equals`) and must
    not be None. When `check_type` is true, the first element read back is
    also checked against the type of the first expected element.
    """
    sample.put_vector("key", expected)
    stored = sample.get_vector("key")
    assert_equals(stored, expected)
    assert stored is not None
    if not check_type:
        return
    _assert_same_type(stored[0], expected[0])
def _assert_roundtrip_vector_list(sample: Sample, expected: Any) -> None:
    """
    Store `expected` with put_vector_list, read it back and compare.

    The value read back must equal `expected` and must not be None. If its
    first entry is not None, the first element of that entry is checked
    against the type of the corresponding expected element.
    """
    sample.put_vector_list("key", expected)
    stored = sample.get_vector_list("key")
    assert stored == expected
    assert stored is not None
    first_row = stored[0]
    if first_row is None:
        return
    _assert_same_type(first_row[0], expected[0][0])
def _assert_same_type(actual: Any, expected: Any) -> None:
assert isinstance(
actual, expected.__class__
), f"Expected {expected.__class__}, found {actual.__class__} instead"
def test_pad_int() -> None:
    """
    Pad/crop roundtrip for integer rows.

    The expected padding fills short, empty and None rows with 0 up to the
    longest row; the expected lengths record a None row as -1.
    """
    ragged = [[1], [2, 2, 2], [], [3, 3], [4, 4, 4, 4], None]
    padded = [
        [1, 0, 0, 0],
        [2, 2, 2, 0],
        [0, 0, 0, 0],
        [3, 3, 0, 0],
        [4, 4, 4, 4],
        [0, 0, 0, 0],
    ]
    lengths = [1, 3, 0, 2, 4, -1]
    _assert_roundtrip_pad(
        original=ragged,
        expected_padded=padded,
        expected_lens=lengths,
        dtype=int,
    )
def test_pad_float() -> None:
    """
    Pad/crop roundtrip for float rows.

    The expected padding fills short and None rows with 0.0 up to the longest
    row; the expected lengths record a None row as -1.
    """
    ragged = [[1.0], [2.0, 2.0, 2.0], [3.0, 3.0], [4.0, 4.0, 4.0, 4.0], None]
    padded = [
        [1.0, 0.0, 0.0, 0.0],
        [2.0, 2.0, 2.0, 0.0],
        [3.0, 3.0, 0.0, 0.0],
        [4.0, 4.0, 4.0, 4.0],
        [0.0, 0.0, 0.0, 0.0],
    ]
    lengths = [1, 3, 2, 4, -1]
    _assert_roundtrip_pad(
        original=ragged,
        expected_padded=padded,
        expected_lens=lengths,
        dtype=float,
    )
def test_pad_str() -> None:
    """
    Pad/crop roundtrip for string rows.

    The expected padding fills short rows with empty strings up to the
    longest row.
    """
    ragged = [["A"], ["B", "B", "B"], ["C", "C"]]
    padded = [["A", "", ""], ["B", "B", "B"], ["C", "C", ""]]
    lengths = [1, 3, 2]
    _assert_roundtrip_pad(
        original=ragged,
        expected_padded=padded,
        expected_lens=lengths,
        dtype=str,
    )
def _assert_roundtrip_pad(
    original: Any,
    expected_padded: Any,
    expected_lens: Any,
    dtype: Any,
) -> None:
    """
    Check `_pad` followed by `_crop` on `original`.

    `_pad(original)` must return `expected_padded` and `expected_lens`, and
    `_crop` applied to that result must give back `original`. Every element
    of the padded rows, and of the non-None cropped rows, must be an instance
    of `dtype`.
    """
    padded, lens = _pad(original)
    assert padded == expected_padded
    assert lens == expected_lens
    for row in padded:
        for item in row:  # type: ignore
            assert isinstance(item, dtype)

    restored = _crop(padded, lens)
    assert restored == original
    for row in restored:
        # None rows have no elements to type-check.
        if row is not None:
            for item in row:  # type: ignore
                assert isinstance(item, dtype)

Loading…
Cancel
Save