diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index 422c979..11f37b4 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -110,10 +110,12 @@ class MemorySample(Sample): def __init__( self, data: Optional[Dict[str, Any]] = None, + check_data: bool = False, ) -> None: if data is None: data = {} self._data: Dict[str, Any] = data + self._check_data = check_data @overrides def get_bytes(self, key: str) -> Optional[Bytes]: @@ -142,19 +144,22 @@ class MemorySample(Sample): def put_scalar(self, key: str, value: Scalar) -> None: if value is None: return - self._assert_is_scalar(value) + if self._check_data: + self._assert_is_scalar(value) self._put(key, value) @overrides def put_vector(self, key: str, value: Vector) -> None: if value is None: return - self._assert_is_vector(value) + if self._check_data: + self._assert_is_vector(value) self._put(key, value) @overrides def put_vector_list(self, key: str, value: VectorList) -> None: - self._assert_is_vector_list(value) + if self._check_data: + self._assert_is_vector_list(value) self._put(key, value) def _get(self, key: str) -> Optional[Any]: @@ -175,8 +180,14 @@ class Hdf5Sample(Sample): are actually accessed, and therefore it is more scalable. """ - def __init__(self, filename: str, mode: str = "r+") -> None: + def __init__( + self, + filename: str, + mode: str = "r+", + check_data: bool = False, + ) -> None: self.file = h5py.File(filename, mode, libver="latest") + self._check_data = check_data @overrides def get_bytes(self, key: str) -> Optional[Bytes]: @@ -230,27 +241,30 @@ class Hdf5Sample(Sample): @overrides def put_bytes(self, key: str, value: Bytes) -> None: - assert isinstance( - value, (bytes, bytearray) - ), f"bytes expected; found: {value}" # type: ignore + if self._check_data: + assert isinstance( + value, (bytes, bytearray) + ), f"bytes expected; found: {value}" # type: ignore self._put(key, np.frombuffer(value, dtype="uint8"), compress=True) @overrides def put_scalar(self, key: str, value: Any) -> None: if value is None: return - self._assert_is_scalar(value) + if self._check_data: + self._assert_is_scalar(value) self._put(key, value) @overrides def put_vector(self, key: str, value: Vector) -> None: if value is None: return - self._assert_is_vector(value) + if self._check_data: + self._assert_is_vector(value) for v in value: # Convert strings to bytes - if isinstance(v, str): + if isinstance(v, str) or v is None: value = np.array( [u if u is not None else b"" for u in value], dtype="S", @@ -266,7 +280,8 @@ class Hdf5Sample(Sample): @overrides def put_vector_list(self, key: str, value: VectorList) -> None: - self._assert_is_vector_list(value) + if self._check_data: + self._assert_is_vector_list(value) padded, lens = _pad(value) self.put_vector(f"{key}_lengths", lens) data = None @@ -297,7 +312,6 @@ class Hdf5Sample(Sample): def _pad(veclist: VectorList) -> Tuple[VectorList, List[int]]: - veclist = deepcopy(veclist) lens = [len(v) if v is not None else -1 for v in veclist] maxlen = max(lens) diff --git a/tests/features/test_extractor.py b/tests/features/test_extractor.py index 1053b4b..9edb496 100644 --- a/tests/features/test_extractor.py +++ b/tests/features/test_extractor.py @@ -2,13 +2,20 @@ # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. +import sys +import time +from typing import Any + import numpy as np +import gurobipy as gp from miplearn.features.extractor import FeaturesExtractor -from miplearn.features.sample import Sample, MemorySample -from miplearn.solvers.internal import Variables, Constraints +from miplearn.features.sample import MemorySample, Hdf5Sample +from miplearn.instance.base import Instance from miplearn.solvers.gurobi import GurobiSolver +from miplearn.solvers.internal import Variables, Constraints from miplearn.solvers.tests import assert_equals +import cProfile inf = float("inf") @@ -166,3 +173,27 @@ def test_assert_equals() -> None: assert_equals(np.array([True, True]), [True, True]) assert_equals((1.0,), (1.0,)) assert_equals({"x": 10}, {"x": 10}) + + +class MpsInstance(Instance): + def __init__(self, filename: str) -> None: + super().__init__() + self.filename = filename + + def to_model(self) -> Any: + return gp.read(self.filename) + + +if __name__ == "__main__": + solver = GurobiSolver() + instance = MpsInstance(sys.argv[1]) + solver.set_instance(instance) + solver.solve_lp(tee=True) + extractor = FeaturesExtractor(with_lhs=False) + sample = Hdf5Sample("tmp/prof.h5", mode="w") + + def run(): + extractor.extract_after_load_features(instance, solver, sample) + extractor.extract_after_lp_features(solver, sample) + + cProfile.run("run()", filename="tmp/prof")