Split Sample.{get,put} into {get,put}_{scalar,vector}

master
Alinson S. Xavier 4 years ago
parent 0a399deeee
commit 8fc7c6ab71
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -33,15 +33,15 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=True)
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
sample.put("var_lower_bounds", variables.lower_bounds)
sample.put("var_names", variables.names)
sample.put("var_obj_coeffs", variables.obj_coeffs)
sample.put("var_types", variables.types)
sample.put("var_upper_bounds", variables.upper_bounds)
sample.put("constr_names", constraints.names)
sample.put_vector("var_lower_bounds", variables.lower_bounds)
sample.put_vector("var_names", variables.names)
sample.put_vector("var_obj_coeffs", variables.obj_coeffs)
sample.put_vector("var_types", variables.types)
sample.put_vector("var_upper_bounds", variables.upper_bounds)
sample.put_vector("constr_names", constraints.names)
sample.put("constr_lhs", constraints.lhs)
sample.put("constr_rhs", constraints.rhs)
sample.put("constr_senses", constraints.senses)
sample.put_vector("constr_rhs", constraints.rhs)
sample.put_vector("constr_senses", constraints.senses)
self._extract_user_features_vars(instance, sample)
self._extract_user_features_constrs(instance, sample)
self._extract_user_features_instance(instance, sample)
@ -67,20 +67,20 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=False, with_sa=self.with_sa)
constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa)
sample.put("lp_var_basis_status", variables.basis_status)
sample.put("lp_var_reduced_costs", variables.reduced_costs)
sample.put("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put("lp_var_values", variables.values)
sample.put("lp_constr_basis_status", constraints.basis_status)
sample.put("lp_constr_dual_values", constraints.dual_values)
sample.put("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put("lp_constr_slacks", constraints.slacks)
sample.put_vector("lp_var_basis_status", variables.basis_status)
sample.put_vector("lp_var_reduced_costs", variables.reduced_costs)
sample.put_vector("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put_vector("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put_vector("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put_vector("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put_vector("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put_vector("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put_vector("lp_var_values", variables.values)
sample.put_vector("lp_constr_basis_status", constraints.basis_status)
sample.put_vector("lp_constr_dual_values", constraints.dual_values)
sample.put_vector("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put_vector("lp_constr_slacks", constraints.slacks)
self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_")
sample.put(
"lp_var_features",
@ -134,8 +134,8 @@ class FeaturesExtractor:
) -> None:
variables = solver.get_variables(with_static=False, with_sa=False)
constraints = solver.get_constraints(with_static=False, with_sa=False)
sample.put("mip_var_values", variables.values)
sample.put("mip_constr_slacks", constraints.slacks)
sample.put_vector("mip_var_values", variables.values)
sample.put_vector("mip_constr_slacks", constraints.slacks)
def _extract_user_features_vars(
self,
@ -228,7 +228,7 @@ class FeaturesExtractor:
else:
lazy.append(False)
sample.put("constr_features_user", user_features)
sample.put("constr_lazy", lazy)
sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories)
def _extract_user_features_instance(
@ -251,7 +251,7 @@ class FeaturesExtractor:
constr_lazy = sample.get("constr_lazy")
assert constr_lazy is not None
sample.put("instance_features_user", user_features)
sample.put("static_lazy_count", sum(constr_lazy))
sample.put_scalar("static_lazy_count", sum(constr_lazy))
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
# approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195.

@ -3,16 +3,34 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
from typing import Dict, Optional, Any, Union, List
import h5py
import numpy as np
from overrides import overrides
# Type aliases for the values accepted by Sample.put_scalar / put_vector.
# A Scalar is a single primitive value (or None); a Vector is a homogeneous
# list of primitives (or None).
Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]]
class Sample(ABC):
"""Abstract dictionary-like class that stores training data."""
@abstractmethod
def get_scalar(self, key: str) -> Optional[Any]:
    """Retrieve the scalar value previously stored under *key*."""
    pass
@abstractmethod
def put_scalar(self, key: str, value: Scalar) -> None:
    """Store a single scalar value under *key*."""
    pass
@abstractmethod
def get_vector(self, key: str) -> Optional[Any]:
    """Retrieve the vector (list of scalars) previously stored under *key*."""
    pass
@abstractmethod
def put_vector(self, key: str, value: Vector) -> None:
    """Store a list of scalar values under *key*."""
    pass
@abstractmethod
def get(self, key: str) -> Optional[Any]:
    # Generic untyped getter, retained alongside the typed
    # get_scalar/get_vector accessors introduced by this change.
    """Retrieve the value previously stored under *key*."""
    pass
@ -33,6 +51,8 @@ class Sample(ABC):
def _is_primitive(v: Any) -> bool:
if isinstance(v, (str, bool, int, float)):
return True
if v is None:
return True
return False
if _is_primitive(value):
@ -40,8 +60,23 @@ class Sample(ABC):
if isinstance(value, list):
if _is_primitive(value[0]):
return
if isinstance(value[0], list):
if _is_primitive(value[0][0]):
return
assert False, f"Value has unsupported type: {value}"
def _assert_scalar(self, value: Any) -> None:
if value is None:
return
if isinstance(value, (str, bool, int, float)):
return
assert False, f"Scalar expected; found instead: {value}"
def _assert_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
self._assert_scalar(v)
class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory."""
@ -54,6 +89,26 @@ class MemorySample(Sample):
data = {}
self._data: Dict[str, Any] = data
@overrides
def get_scalar(self, key: str) -> Optional[Any]:
    """Retrieve a scalar by delegating to the generic `get`."""
    return self.get(key)
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
    """Validate that *value* is a scalar, then store it via `put`.

    NOTE: unlike put_vector, a None value is validated and stored,
    not silently skipped.
    """
    self._assert_scalar(value)
    self.put(key, value)
@overrides
def get_vector(self, key: str) -> Optional[Any]:
    """Retrieve a vector by delegating to the generic `get`."""
    return self.get(key)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
    """Validate that *value* is a list of scalars and store it via `put`.
    A None value is silently ignored (nothing is stored)."""
    if value is not None:
        self._assert_vector(value)
        self.put(key, value)
@overrides
def get(self, key: str) -> Optional[Any]:
if key in self._data:
@ -63,7 +118,6 @@ class MemorySample(Sample):
@overrides
def put(self, key: str, value: Any) -> None:
    """Store *value* under *key* in the in-memory dictionary.

    No validation is performed here; the typed put_scalar/put_vector
    entry points validate before delegating to this method.
    """
    self._data[key] = value
@ -79,22 +133,45 @@ class Hdf5Sample(Sample):
self.file = h5py.File(filename, "r+")
@overrides
def get(self, key: str) -> Optional[Any]:
def get_scalar(self, key: str) -> Optional[Any]:
ds = self.file[key]
assert len(ds.shape) == 0
if h5py.check_string_dtype(ds.dtype):
if ds.shape == ():
return ds.asstr()[()]
else:
return ds[()].tolist()
@overrides
def get_vector(self, key: str) -> Optional[Any]:
ds = self.file[key]
assert len(ds.shape) == 1
if h5py.check_string_dtype(ds.dtype):
return ds.asstr()[:].tolist()
else:
if ds.shape == ():
return ds[()].tolist()
return ds[:].tolist()
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
    """Validate that *value* is a scalar, then store it via `put`."""
    self._assert_scalar(value)
    self.put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
    """Validate that *value* is a list of scalars and store it via `put`.
    A None value is silently ignored (nothing is stored)."""
    if value is not None:
        self._assert_vector(value)
        self.put(key, value)
@overrides
def get(self, key: str) -> Optional[Any]:
    # Generic getter retained alongside get_scalar/get_vector. Reads the
    # dataset with a full slice (`[:]`), so it assumes a non-zero-rank
    # dataset — TODO(review): confirm scalar datasets are no longer read
    # through this path now that get_scalar exists.
    ds = self.file[key]
    if h5py.check_string_dtype(ds.dtype):
        # Decode HDF5 string data back into Python str objects.
        return ds.asstr()[:].tolist()
    else:
        return ds[:].tolist()
@overrides
def put(self, key: str, value: Any) -> None:
    """Store *value* as an HDF5 dataset, replacing any existing dataset."""
    self._assert_supported(value)
    # h5py cannot overwrite an existing dataset in place; delete it first.
    if key in self.file:
        del self.file[key]
    self.file.create_dataset(key, data=value)

@ -210,7 +210,7 @@ class LearningSolver:
logger.info("Extracting features (after-lp)...")
initial_time = time.time()
for (k, v) in lp_stats.__dict__.items():
sample.put(k, v)
sample.put_scalar(k, v)
self.extractor.extract_after_lp_features(self.internal_solver, sample)
logger.info(
"Features (after-lp) extracted in %.2f seconds"
@ -278,7 +278,7 @@ class LearningSolver:
logger.info("Extracting features (after-mip)...")
initial_time = time.time()
for (k, v) in mip_stats.__dict__.items():
sample.put(k, v)
sample.put_scalar(k, v)
self.extractor.extract_after_mip_features(self.internal_solver, sample)
logger.info(
"Features (after-mip) extracted in %.2f seconds"

@ -7,34 +7,6 @@ from typing import Any
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
def _test_sample(sample: Sample) -> None:
_assert_roundtrip(sample, "A")
_assert_roundtrip(sample, True)
_assert_roundtrip(sample, 1)
_assert_roundtrip(sample, 1.0)
_assert_roundtrip(sample, ["A", "BB", "CCC", "こんにちは"])
_assert_roundtrip(sample, [True, True, False])
_assert_roundtrip(sample, [1, 2, 3])
_assert_roundtrip(sample, [1.0, 2.0, 3.0])
def _assert_roundtrip(sample: Sample, expected: Any) -> None:
sample.put("key", expected)
actual = sample.get("key")
assert actual == expected
assert actual is not None
if isinstance(actual, list):
assert isinstance(actual[0], expected[0].__class__), (
f"Expected class {expected[0].__class__}, "
f"found {actual[0].__class__} instead"
)
else:
assert isinstance(actual, expected.__class__), (
f"Expected class {expected.__class__}, "
f"found class {actual.__class__} instead"
)
def test_memory_sample() -> None:
    """Run the shared Sample round-trip checks against the in-memory backend."""
    _test_sample(MemorySample())
@ -42,3 +14,51 @@ def test_memory_sample() -> None:
def test_hdf5_sample() -> None:
    """Run the shared Sample round-trip checks against the HDF5 backend,
    backed by a temporary file that is removed when the test ends."""
    file = NamedTemporaryFile()
    _test_sample(Hdf5Sample(file.name))
def _test_sample(sample: Sample) -> None:
    """Exercise scalar and vector round-trips on the given Sample backend,
    covering each supported primitive type (str, bool, int, float)."""
    # Scalar
    _assert_roundtrip_scalar(sample, "A")
    _assert_roundtrip_scalar(sample, True)
    _assert_roundtrip_scalar(sample, 1)
    _assert_roundtrip_scalar(sample, 1.0)
    # Vector
    _assert_roundtrip_vector(sample, ["A", "BB", "CCC", "こんにちは"])
    _assert_roundtrip_vector(sample, [True, True, False])
    _assert_roundtrip_vector(sample, [1, 2, 3])
    _assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
    # List[Optional[List[Primitive]]]
    # TODO(review): ragged / None-containing nested vectors are not yet
    # supported by put_vector; re-enable this case once they are.
    # _assert_roundtrip(
    #     sample,
    #     [
    #         [1],
    #         None,
    #         [2, 2],
    #         [3, 3, 3],
    #     ],
    # )
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
    """Store *expected* via put_scalar, read it back via get_scalar, and
    assert that both the value and its type survive the round trip."""
    sample.put_scalar("key", expected)
    actual = sample.get_scalar("key")
    assert actual == expected
    assert actual is not None
    _assert_same_type(actual, expected)
def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None:
    """Store *expected* via put_vector, read it back via get_vector, and
    assert that the value and element type survive the round trip.
    Requires *expected* to be non-empty (the first element is inspected)."""
    sample.put_vector("key", expected)
    actual = sample.get_vector("key")
    assert actual == expected
    assert actual is not None
    _assert_same_type(actual[0], expected[0])
def _assert_same_type(actual: Any, expected: Any) -> None:
assert isinstance(actual, expected.__class__), (
f"Expected class {expected.__class__}, "
f"found class {actual.__class__} instead"
)

Loading…
Cancel
Save