Split Sample.{get,put} into {get,put}_{scalar,vector}

master
Alinson S. Xavier 4 years ago
parent 0a399deeee
commit 8fc7c6ab71
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -33,15 +33,15 @@ class FeaturesExtractor:
) -> None: ) -> None:
variables = solver.get_variables(with_static=True) variables = solver.get_variables(with_static=True)
constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs) constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs)
sample.put("var_lower_bounds", variables.lower_bounds) sample.put_vector("var_lower_bounds", variables.lower_bounds)
sample.put("var_names", variables.names) sample.put_vector("var_names", variables.names)
sample.put("var_obj_coeffs", variables.obj_coeffs) sample.put_vector("var_obj_coeffs", variables.obj_coeffs)
sample.put("var_types", variables.types) sample.put_vector("var_types", variables.types)
sample.put("var_upper_bounds", variables.upper_bounds) sample.put_vector("var_upper_bounds", variables.upper_bounds)
sample.put("constr_names", constraints.names) sample.put_vector("constr_names", constraints.names)
sample.put("constr_lhs", constraints.lhs) sample.put("constr_lhs", constraints.lhs)
sample.put("constr_rhs", constraints.rhs) sample.put_vector("constr_rhs", constraints.rhs)
sample.put("constr_senses", constraints.senses) sample.put_vector("constr_senses", constraints.senses)
self._extract_user_features_vars(instance, sample) self._extract_user_features_vars(instance, sample)
self._extract_user_features_constrs(instance, sample) self._extract_user_features_constrs(instance, sample)
self._extract_user_features_instance(instance, sample) self._extract_user_features_instance(instance, sample)
@ -67,20 +67,20 @@ class FeaturesExtractor:
) -> None: ) -> None:
variables = solver.get_variables(with_static=False, with_sa=self.with_sa) variables = solver.get_variables(with_static=False, with_sa=self.with_sa)
constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa) constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa)
sample.put("lp_var_basis_status", variables.basis_status) sample.put_vector("lp_var_basis_status", variables.basis_status)
sample.put("lp_var_reduced_costs", variables.reduced_costs) sample.put_vector("lp_var_reduced_costs", variables.reduced_costs)
sample.put("lp_var_sa_lb_down", variables.sa_lb_down) sample.put_vector("lp_var_sa_lb_down", variables.sa_lb_down)
sample.put("lp_var_sa_lb_up", variables.sa_lb_up) sample.put_vector("lp_var_sa_lb_up", variables.sa_lb_up)
sample.put("lp_var_sa_obj_down", variables.sa_obj_down) sample.put_vector("lp_var_sa_obj_down", variables.sa_obj_down)
sample.put("lp_var_sa_obj_up", variables.sa_obj_up) sample.put_vector("lp_var_sa_obj_up", variables.sa_obj_up)
sample.put("lp_var_sa_ub_down", variables.sa_ub_down) sample.put_vector("lp_var_sa_ub_down", variables.sa_ub_down)
sample.put("lp_var_sa_ub_up", variables.sa_ub_up) sample.put_vector("lp_var_sa_ub_up", variables.sa_ub_up)
sample.put("lp_var_values", variables.values) sample.put_vector("lp_var_values", variables.values)
sample.put("lp_constr_basis_status", constraints.basis_status) sample.put_vector("lp_constr_basis_status", constraints.basis_status)
sample.put("lp_constr_dual_values", constraints.dual_values) sample.put_vector("lp_constr_dual_values", constraints.dual_values)
sample.put("lp_constr_sa_rhs_down", constraints.sa_rhs_down) sample.put_vector("lp_constr_sa_rhs_down", constraints.sa_rhs_down)
sample.put("lp_constr_sa_rhs_up", constraints.sa_rhs_up) sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up)
sample.put("lp_constr_slacks", constraints.slacks) sample.put_vector("lp_constr_slacks", constraints.slacks)
self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_") self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_")
sample.put( sample.put(
"lp_var_features", "lp_var_features",
@ -134,8 +134,8 @@ class FeaturesExtractor:
) -> None: ) -> None:
variables = solver.get_variables(with_static=False, with_sa=False) variables = solver.get_variables(with_static=False, with_sa=False)
constraints = solver.get_constraints(with_static=False, with_sa=False) constraints = solver.get_constraints(with_static=False, with_sa=False)
sample.put("mip_var_values", variables.values) sample.put_vector("mip_var_values", variables.values)
sample.put("mip_constr_slacks", constraints.slacks) sample.put_vector("mip_constr_slacks", constraints.slacks)
def _extract_user_features_vars( def _extract_user_features_vars(
self, self,
@ -228,7 +228,7 @@ class FeaturesExtractor:
else: else:
lazy.append(False) lazy.append(False)
sample.put("constr_features_user", user_features) sample.put("constr_features_user", user_features)
sample.put("constr_lazy", lazy) sample.put_vector("constr_lazy", lazy)
sample.put("constr_categories", categories) sample.put("constr_categories", categories)
def _extract_user_features_instance( def _extract_user_features_instance(
@ -251,7 +251,7 @@ class FeaturesExtractor:
constr_lazy = sample.get("constr_lazy") constr_lazy = sample.get("constr_lazy")
assert constr_lazy is not None assert constr_lazy is not None
sample.put("instance_features_user", user_features) sample.put("instance_features_user", user_features)
sample.put("static_lazy_count", sum(constr_lazy)) sample.put_scalar("static_lazy_count", sum(constr_lazy))
# Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based # Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based
# approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195. # approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195.

@ -3,16 +3,34 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Optional, Any from typing import Dict, Optional, Any, Union, List
import h5py import h5py
import numpy as np
from overrides import overrides from overrides import overrides
Scalar = Union[None, bool, str, int, float]
Vector = Union[None, List[bool], List[str], List[int], List[float]]
class Sample(ABC): class Sample(ABC):
"""Abstract dictionary-like class that stores training data.""" """Abstract dictionary-like class that stores training data."""
@abstractmethod
def get_scalar(self, key: str) -> Optional[Any]:
pass
@abstractmethod
def put_scalar(self, key: str, value: Scalar) -> None:
pass
@abstractmethod
def get_vector(self, key: str) -> Optional[Any]:
pass
@abstractmethod
def put_vector(self, key: str, value: Vector) -> None:
pass
@abstractmethod @abstractmethod
def get(self, key: str) -> Optional[Any]: def get(self, key: str) -> Optional[Any]:
pass pass
@ -33,6 +51,8 @@ class Sample(ABC):
def _is_primitive(v: Any) -> bool: def _is_primitive(v: Any) -> bool:
if isinstance(v, (str, bool, int, float)): if isinstance(v, (str, bool, int, float)):
return True return True
if v is None:
return True
return False return False
if _is_primitive(value): if _is_primitive(value):
@ -40,8 +60,23 @@ class Sample(ABC):
if isinstance(value, list): if isinstance(value, list):
if _is_primitive(value[0]): if _is_primitive(value[0]):
return return
if isinstance(value[0], list):
if _is_primitive(value[0][0]):
return
assert False, f"Value has unsupported type: {value}" assert False, f"Value has unsupported type: {value}"
def _assert_scalar(self, value: Any) -> None:
if value is None:
return
if isinstance(value, (str, bool, int, float)):
return
assert False, f"Scalar expected; found instead: {value}"
def _assert_vector(self, value: Any) -> None:
assert isinstance(value, list), f"List expected; found instead: {value}"
for v in value:
self._assert_scalar(v)
class MemorySample(Sample): class MemorySample(Sample):
"""Dictionary-like class that stores training data in-memory.""" """Dictionary-like class that stores training data in-memory."""
@ -54,6 +89,26 @@ class MemorySample(Sample):
data = {} data = {}
self._data: Dict[str, Any] = data self._data: Dict[str, Any] = data
@overrides
def get_scalar(self, key: str) -> Optional[Any]:
return self.get(key)
@overrides
def put_scalar(self, key: str, value: Scalar) -> None:
self._assert_scalar(value)
self.put(key, value)
@overrides
def get_vector(self, key: str) -> Optional[Any]:
return self.get(key)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_vector(value)
self.put(key, value)
@overrides @overrides
def get(self, key: str) -> Optional[Any]: def get(self, key: str) -> Optional[Any]:
if key in self._data: if key in self._data:
@ -63,7 +118,6 @@ class MemorySample(Sample):
@overrides @overrides
def put(self, key: str, value: Any) -> None: def put(self, key: str, value: Any) -> None:
# self._assert_supported(value)
self._data[key] = value self._data[key] = value
@ -79,22 +133,45 @@ class Hdf5Sample(Sample):
self.file = h5py.File(filename, "r+") self.file = h5py.File(filename, "r+")
@overrides @overrides
def get(self, key: str) -> Optional[Any]: def get_scalar(self, key: str) -> Optional[Any]:
ds = self.file[key] ds = self.file[key]
assert len(ds.shape) == 0
if h5py.check_string_dtype(ds.dtype): if h5py.check_string_dtype(ds.dtype):
if ds.shape == ():
return ds.asstr()[()] return ds.asstr()[()]
else: else:
return ds[()].tolist()
@overrides
def get_vector(self, key: str) -> Optional[Any]:
ds = self.file[key]
assert len(ds.shape) == 1
if h5py.check_string_dtype(ds.dtype):
return ds.asstr()[:].tolist() return ds.asstr()[:].tolist()
else: else:
if ds.shape == (): return ds[:].tolist()
return ds[()].tolist()
@overrides
def put_scalar(self, key: str, value: Any) -> None:
self._assert_scalar(value)
self.put(key, value)
@overrides
def put_vector(self, key: str, value: Vector) -> None:
if value is None:
return
self._assert_vector(value)
self.put(key, value)
@overrides
def get(self, key: str) -> Optional[Any]:
ds = self.file[key]
if h5py.check_string_dtype(ds.dtype):
return ds.asstr()[:].tolist()
else: else:
return ds[:].tolist() return ds[:].tolist()
@overrides @overrides
def put(self, key: str, value: Any) -> None: def put(self, key: str, value: Any) -> None:
self._assert_supported(value)
if key in self.file: if key in self.file:
del self.file[key] del self.file[key]
self.file.create_dataset(key, data=value) self.file.create_dataset(key, data=value)

@ -210,7 +210,7 @@ class LearningSolver:
logger.info("Extracting features (after-lp)...") logger.info("Extracting features (after-lp)...")
initial_time = time.time() initial_time = time.time()
for (k, v) in lp_stats.__dict__.items(): for (k, v) in lp_stats.__dict__.items():
sample.put(k, v) sample.put_scalar(k, v)
self.extractor.extract_after_lp_features(self.internal_solver, sample) self.extractor.extract_after_lp_features(self.internal_solver, sample)
logger.info( logger.info(
"Features (after-lp) extracted in %.2f seconds" "Features (after-lp) extracted in %.2f seconds"
@ -278,7 +278,7 @@ class LearningSolver:
logger.info("Extracting features (after-mip)...") logger.info("Extracting features (after-mip)...")
initial_time = time.time() initial_time = time.time()
for (k, v) in mip_stats.__dict__.items(): for (k, v) in mip_stats.__dict__.items():
sample.put(k, v) sample.put_scalar(k, v)
self.extractor.extract_after_mip_features(self.internal_solver, sample) self.extractor.extract_after_mip_features(self.internal_solver, sample)
logger.info( logger.info(
"Features (after-mip) extracted in %.2f seconds" "Features (after-mip) extracted in %.2f seconds"

@ -7,34 +7,6 @@ from typing import Any
from miplearn.features.sample import MemorySample, Sample, Hdf5Sample from miplearn.features.sample import MemorySample, Sample, Hdf5Sample
def _test_sample(sample: Sample) -> None:
_assert_roundtrip(sample, "A")
_assert_roundtrip(sample, True)
_assert_roundtrip(sample, 1)
_assert_roundtrip(sample, 1.0)
_assert_roundtrip(sample, ["A", "BB", "CCC", "こんにちは"])
_assert_roundtrip(sample, [True, True, False])
_assert_roundtrip(sample, [1, 2, 3])
_assert_roundtrip(sample, [1.0, 2.0, 3.0])
def _assert_roundtrip(sample: Sample, expected: Any) -> None:
sample.put("key", expected)
actual = sample.get("key")
assert actual == expected
assert actual is not None
if isinstance(actual, list):
assert isinstance(actual[0], expected[0].__class__), (
f"Expected class {expected[0].__class__}, "
f"found {actual[0].__class__} instead"
)
else:
assert isinstance(actual, expected.__class__), (
f"Expected class {expected.__class__}, "
f"found class {actual.__class__} instead"
)
def test_memory_sample() -> None: def test_memory_sample() -> None:
_test_sample(MemorySample()) _test_sample(MemorySample())
@ -42,3 +14,51 @@ def test_memory_sample() -> None:
def test_hdf5_sample() -> None: def test_hdf5_sample() -> None:
file = NamedTemporaryFile() file = NamedTemporaryFile()
_test_sample(Hdf5Sample(file.name)) _test_sample(Hdf5Sample(file.name))
def _test_sample(sample: Sample) -> None:
# Scalar
_assert_roundtrip_scalar(sample, "A")
_assert_roundtrip_scalar(sample, True)
_assert_roundtrip_scalar(sample, 1)
_assert_roundtrip_scalar(sample, 1.0)
# Vector
_assert_roundtrip_vector(sample, ["A", "BB", "CCC", "こんにちは"])
_assert_roundtrip_vector(sample, [True, True, False])
_assert_roundtrip_vector(sample, [1, 2, 3])
_assert_roundtrip_vector(sample, [1.0, 2.0, 3.0])
# List[Optional[List[Primitive]]]
# _assert_roundtrip(
# sample,
# [
# [1],
# None,
# [2, 2],
# [3, 3, 3],
# ],
# )
def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None:
sample.put_scalar("key", expected)
actual = sample.get_scalar("key")
assert actual == expected
assert actual is not None
_assert_same_type(actual, expected)
def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None:
sample.put_vector("key", expected)
actual = sample.get_vector("key")
assert actual == expected
assert actual is not None
_assert_same_type(actual[0], expected[0])
def _assert_same_type(actual: Any, expected: Any) -> None:
assert isinstance(actual, expected.__class__), (
f"Expected class {expected.__class__}, "
f"found class {actual.__class__} instead"
)

Loading…
Cancel
Save