diff --git a/miplearn/features/extractor.py b/miplearn/features/extractor.py index b26f97f..6648f6d 100644 --- a/miplearn/features/extractor.py +++ b/miplearn/features/extractor.py @@ -33,15 +33,15 @@ class FeaturesExtractor: ) -> None: variables = solver.get_variables(with_static=True) constraints = solver.get_constraints(with_static=True, with_lhs=self.with_lhs) - sample.put("var_lower_bounds", variables.lower_bounds) - sample.put("var_names", variables.names) - sample.put("var_obj_coeffs", variables.obj_coeffs) - sample.put("var_types", variables.types) - sample.put("var_upper_bounds", variables.upper_bounds) - sample.put("constr_names", constraints.names) + sample.put_vector("var_lower_bounds", variables.lower_bounds) + sample.put_vector("var_names", variables.names) + sample.put_vector("var_obj_coeffs", variables.obj_coeffs) + sample.put_vector("var_types", variables.types) + sample.put_vector("var_upper_bounds", variables.upper_bounds) + sample.put_vector("constr_names", constraints.names) sample.put("constr_lhs", constraints.lhs) - sample.put("constr_rhs", constraints.rhs) - sample.put("constr_senses", constraints.senses) + sample.put_vector("constr_rhs", constraints.rhs) + sample.put_vector("constr_senses", constraints.senses) self._extract_user_features_vars(instance, sample) self._extract_user_features_constrs(instance, sample) self._extract_user_features_instance(instance, sample) @@ -67,20 +67,20 @@ class FeaturesExtractor: ) -> None: variables = solver.get_variables(with_static=False, with_sa=self.with_sa) constraints = solver.get_constraints(with_static=False, with_sa=self.with_sa) - sample.put("lp_var_basis_status", variables.basis_status) - sample.put("lp_var_reduced_costs", variables.reduced_costs) - sample.put("lp_var_sa_lb_down", variables.sa_lb_down) - sample.put("lp_var_sa_lb_up", variables.sa_lb_up) - sample.put("lp_var_sa_obj_down", variables.sa_obj_down) - sample.put("lp_var_sa_obj_up", variables.sa_obj_up) - sample.put("lp_var_sa_ub_down", variables.sa_ub_down) - sample.put("lp_var_sa_ub_up", variables.sa_ub_up) - sample.put("lp_var_values", variables.values) - sample.put("lp_constr_basis_status", constraints.basis_status) - sample.put("lp_constr_dual_values", constraints.dual_values) - sample.put("lp_constr_sa_rhs_down", constraints.sa_rhs_down) - sample.put("lp_constr_sa_rhs_up", constraints.sa_rhs_up) - sample.put("lp_constr_slacks", constraints.slacks) + sample.put_vector("lp_var_basis_status", variables.basis_status) + sample.put_vector("lp_var_reduced_costs", variables.reduced_costs) + sample.put_vector("lp_var_sa_lb_down", variables.sa_lb_down) + sample.put_vector("lp_var_sa_lb_up", variables.sa_lb_up) + sample.put_vector("lp_var_sa_obj_down", variables.sa_obj_down) + sample.put_vector("lp_var_sa_obj_up", variables.sa_obj_up) + sample.put_vector("lp_var_sa_ub_down", variables.sa_ub_down) + sample.put_vector("lp_var_sa_ub_up", variables.sa_ub_up) + sample.put_vector("lp_var_values", variables.values) + sample.put_vector("lp_constr_basis_status", constraints.basis_status) + sample.put_vector("lp_constr_dual_values", constraints.dual_values) + sample.put_vector("lp_constr_sa_rhs_down", constraints.sa_rhs_down) + sample.put_vector("lp_constr_sa_rhs_up", constraints.sa_rhs_up) + sample.put_vector("lp_constr_slacks", constraints.slacks) self._extract_var_features_AlvLouWeh2017(sample, prefix="lp_") sample.put( "lp_var_features", @@ -134,8 +134,8 @@ class FeaturesExtractor: ) -> None: variables = solver.get_variables(with_static=False, with_sa=False) constraints = solver.get_constraints(with_static=False, with_sa=False) - sample.put("mip_var_values", variables.values) - sample.put("mip_constr_slacks", constraints.slacks) + sample.put_vector("mip_var_values", variables.values) + sample.put_vector("mip_constr_slacks", constraints.slacks) def _extract_user_features_vars( self, @@ -228,7 +228,7 @@ class FeaturesExtractor: else: lazy.append(False) sample.put("constr_features_user", user_features) - sample.put("constr_lazy", lazy) + sample.put_vector("constr_lazy", lazy) sample.put("constr_categories", categories) def _extract_user_features_instance( @@ -251,7 +251,7 @@ class FeaturesExtractor: constr_lazy = sample.get("constr_lazy") assert constr_lazy is not None sample.put("instance_features_user", user_features) - sample.put("static_lazy_count", sum(constr_lazy)) + sample.put_scalar("static_lazy_count", sum(constr_lazy)) # Alvarez, A. M., Louveaux, Q., & Wehenkel, L. (2017). A machine learning-based # approximation of strong branching. INFORMS Journal on Computing, 29(1), 185-195. diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index bd23d9b..27810c3 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -3,16 +3,34 @@ # Released under the modified BSD license. See COPYING.md for more details. from abc import ABC, abstractmethod -from typing import Dict, Optional, Any +from typing import Dict, Optional, Any, Union, List import h5py -import numpy as np from overrides import overrides +Scalar = Union[None, bool, str, int, float] +Vector = Union[None, List[bool], List[str], List[int], List[float]] + class Sample(ABC): """Abstract dictionary-like class that stores training data.""" + @abstractmethod + def get_scalar(self, key: str) -> Optional[Any]: + pass + + @abstractmethod + def put_scalar(self, key: str, value: Scalar) -> None: + pass + + @abstractmethod + def get_vector(self, key: str) -> Optional[Any]: + pass + + @abstractmethod + def put_vector(self, key: str, value: Vector) -> None: + pass + @abstractmethod def get(self, key: str) -> Optional[Any]: pass @@ -33,6 +51,8 @@ class Sample(ABC): def _is_primitive(v: Any) -> bool: if isinstance(v, (str, bool, int, float)): return True + if v is None: + return True return False if _is_primitive(value): @@ -40,8 +60,23 @@ class Sample(ABC): if isinstance(value, list): if _is_primitive(value[0]): return + if isinstance(value[0], list): + if _is_primitive(value[0][0]): + return assert False, f"Value has unsupported type: {value}" + def _assert_scalar(self, value: Any) -> None: + if value is None: + return + if isinstance(value, (str, bool, int, float)): + return + assert False, f"Scalar expected; found instead: {value}" + + def _assert_vector(self, value: Any) -> None: + assert isinstance(value, list), f"List expected; found instead: {value}" + for v in value: + self._assert_scalar(v) + class MemorySample(Sample): """Dictionary-like class that stores training data in-memory.""" @@ -54,6 +89,26 @@ class MemorySample(Sample): data = {} self._data: Dict[str, Any] = data + @overrides + def get_scalar(self, key: str) -> Optional[Any]: + return self.get(key) + + @overrides + def put_scalar(self, key: str, value: Scalar) -> None: + self._assert_scalar(value) + self.put(key, value) + + @overrides + def get_vector(self, key: str) -> Optional[Any]: + return self.get(key) + + @overrides + def put_vector(self, key: str, value: Vector) -> None: + if value is None: + return + self._assert_vector(value) + self.put(key, value) + @overrides def get(self, key: str) -> Optional[Any]: if key in self._data: @@ -63,7 +118,6 @@ class MemorySample(Sample): @overrides def put(self, key: str, value: Any) -> None: - # self._assert_supported(value) self._data[key] = value @@ -78,23 +132,46 @@ class Hdf5Sample(Sample): def __init__(self, filename: str) -> None: self.file = h5py.File(filename, "r+") + @overrides + def get_scalar(self, key: str) -> Optional[Any]: + ds = self.file[key] + assert len(ds.shape) == 0 + if h5py.check_string_dtype(ds.dtype): + return ds.asstr()[()] + else: + return ds[()].tolist() + + @overrides + def get_vector(self, key: str) -> Optional[Any]: + ds = self.file[key] + assert len(ds.shape) == 1 + if h5py.check_string_dtype(ds.dtype): + return ds.asstr()[:].tolist() + else: + return ds[:].tolist() + + @overrides + def put_scalar(self, key: str, value: Any) -> None: + self._assert_scalar(value) + self.put(key, value) + + @overrides + def put_vector(self, key: str, value: Vector) -> None: + if value is None: + return + self._assert_vector(value) + self.put(key, value) + @overrides def get(self, key: str) -> Optional[Any]: ds = self.file[key] if h5py.check_string_dtype(ds.dtype): - if ds.shape == (): - return ds.asstr()[()] - else: - return ds.asstr()[:].tolist() + return ds.asstr()[:].tolist() else: - if ds.shape == (): - return ds[()].tolist() - else: - return ds[:].tolist() + return ds[:].tolist() @overrides def put(self, key: str, value: Any) -> None: - self._assert_supported(value) if key in self.file: del self.file[key] self.file.create_dataset(key, data=value) diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index a072514..efb10b7 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -210,7 +210,7 @@ class LearningSolver: logger.info("Extracting features (after-lp)...") initial_time = time.time() for (k, v) in lp_stats.__dict__.items(): - sample.put(k, v) + sample.put_scalar(k, v) self.extractor.extract_after_lp_features(self.internal_solver, sample) logger.info( "Features (after-lp) extracted in %.2f seconds" @@ -278,7 +278,7 @@ class LearningSolver: logger.info("Extracting features (after-mip)...") initial_time = time.time() for (k, v) in mip_stats.__dict__.items(): - sample.put(k, v) + sample.put_scalar(k, v) self.extractor.extract_after_mip_features(self.internal_solver, sample) logger.info( "Features (after-mip) extracted in %.2f seconds" diff --git a/tests/features/test_sample.py b/tests/features/test_sample.py index 3cdb5e7..64b74c8 100644 --- a/tests/features/test_sample.py +++ b/tests/features/test_sample.py @@ -7,34 +7,6 @@ from typing import Any from miplearn.features.sample import MemorySample, Sample, Hdf5Sample -def _test_sample(sample: Sample) -> None: - _assert_roundtrip(sample, "A") - _assert_roundtrip(sample, True) - _assert_roundtrip(sample, 1) - _assert_roundtrip(sample, 1.0) - _assert_roundtrip(sample, ["A", "BB", "CCC", "こんにちは"]) - _assert_roundtrip(sample, [True, True, False]) - _assert_roundtrip(sample, [1, 2, 3]) - _assert_roundtrip(sample, [1.0, 2.0, 3.0]) - - -def _assert_roundtrip(sample: Sample, expected: Any) -> None: - sample.put("key", expected) - actual = sample.get("key") - assert actual == expected - assert actual is not None - if isinstance(actual, list): - assert isinstance(actual[0], expected[0].__class__), ( - f"Expected class {expected[0].__class__}, " - f"found {actual[0].__class__} instead" - ) - else: - assert isinstance(actual, expected.__class__), ( - f"Expected class {expected.__class__}, " - f"found class {actual.__class__} instead" - ) - - def test_memory_sample() -> None: _test_sample(MemorySample()) @@ -42,3 +14,51 @@ def test_memory_sample() -> None: def test_hdf5_sample() -> None: file = NamedTemporaryFile() _test_sample(Hdf5Sample(file.name)) + + +def _test_sample(sample: Sample) -> None: + # Scalar + _assert_roundtrip_scalar(sample, "A") + _assert_roundtrip_scalar(sample, True) + _assert_roundtrip_scalar(sample, 1) + _assert_roundtrip_scalar(sample, 1.0) + + # Vector + _assert_roundtrip_vector(sample, ["A", "BB", "CCC", "こんにちは"]) + _assert_roundtrip_vector(sample, [True, True, False]) + _assert_roundtrip_vector(sample, [1, 2, 3]) + _assert_roundtrip_vector(sample, [1.0, 2.0, 3.0]) + + # List[Optional[List[Primitive]]] + # _assert_roundtrip( + # sample, + # [ + # [1], + # None, + # [2, 2], + # [3, 3, 3], + # ], + # ) + + +def _assert_roundtrip_scalar(sample: Sample, expected: Any) -> None: + sample.put_scalar("key", expected) + actual = sample.get_scalar("key") + assert actual == expected + assert actual is not None + _assert_same_type(actual, expected) + + +def _assert_roundtrip_vector(sample: Sample, expected: Any) -> None: + sample.put_vector("key", expected) + actual = sample.get_vector("key") + assert actual == expected + assert actual is not None + _assert_same_type(actual[0], expected[0]) + + +def _assert_same_type(actual: Any, expected: Any) -> None: + assert isinstance(actual, expected.__class__), ( + f"Expected class {expected.__class__}, " + f"found class {actual.__class__} instead" + )