diff --git a/miplearn/instance/file.py b/miplearn/instance/file.py new file mode 100644 index 0000000..f0e5664 --- /dev/null +++ b/miplearn/instance/file.py @@ -0,0 +1,132 @@ +# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. +# Released under the modified BSD license. See COPYING.md for more details. +import gc +import os +from typing import Any, Optional, List, Dict, TYPE_CHECKING +import pickle + +from overrides import overrides + +from miplearn.features.sample import Hdf5Sample, Sample +from miplearn.instance.base import Instance + +if TYPE_CHECKING: + from miplearn.solvers.learning import InternalSolver + + +class FileInstance(Instance): + def __init__(self, filename: str) -> None: + super().__init__() + assert os.path.exists(filename), f"File not found: {filename}" + self.h5 = Hdf5Sample(filename) + self.instance: Optional[Instance] = None + + # Delegation + # ------------------------------------------------------------------------- + @overrides + def to_model(self) -> Any: + assert self.instance is not None + return self.instance.to_model() + + @overrides + def get_instance_features(self) -> List[float]: + assert self.instance is not None + return self.instance.get_instance_features() + + @overrides + def get_variable_features(self) -> Dict[str, List[float]]: + assert self.instance is not None + return self.instance.get_variable_features() + + @overrides + def get_variable_categories(self) -> Dict[str, str]: + assert self.instance is not None + return self.instance.get_variable_categories() + + @overrides + def get_constraint_features(self) -> Dict[str, List[float]]: + assert self.instance is not None + return self.instance.get_constraint_features() + + @overrides + def get_constraint_categories(self) -> Dict[str, str]: + assert self.instance is not None + return self.instance.get_constraint_categories() + + @overrides + def has_static_lazy_constraints(self) -> bool: + assert self.instance is not None + return self.instance.has_static_lazy_constraints() + + @overrides + def has_dynamic_lazy_constraints(self) -> bool: + assert self.instance is not None + return self.instance.has_dynamic_lazy_constraints() + + @overrides + def is_constraint_lazy(self, cid: str) -> bool: + assert self.instance is not None + return self.instance.is_constraint_lazy(cid) + + @overrides + def find_violated_lazy_constraints( + self, + solver: "InternalSolver", + model: Any, + ) -> List[str]: + assert self.instance is not None + return self.instance.find_violated_lazy_constraints(solver, model) + + @overrides + def enforce_lazy_constraint( + self, + solver: "InternalSolver", + model: Any, + violation: str, + ) -> None: + assert self.instance is not None + self.instance.enforce_lazy_constraint(solver, model, violation) + + @overrides + def find_violated_user_cuts(self, model: Any) -> List[str]: + assert self.instance is not None + return self.instance.find_violated_user_cuts(model) + + @overrides + def enforce_user_cut( + self, + solver: "InternalSolver", + model: Any, + violation: str, + ) -> None: + assert self.instance is not None + self.instance.enforce_user_cut(solver, model, violation) + + # Input & Output + # ------------------------------------------------------------------------- + @overrides + def free(self) -> None: + self.instance = None + gc.collect() + + @overrides + def load(self) -> None: + if self.instance is not None: + return + self.instance = pickle.loads(self.h5.get_bytes("pickled")) + assert isinstance(self.instance, Instance) + + @classmethod + def save(cls, instance: Instance, filename: str) -> None: + h5 = Hdf5Sample(filename) + instance_pkl = pickle.dumps(instance) + h5.put_bytes("pickled", instance_pkl) + + @overrides + def create_sample(self) -> Sample: + return self.h5 + + @overrides + def get_samples(self) -> List[Sample]: + return [self.h5] diff --git a/tests/instance/test_file.py b/tests/instance/test_file.py new file mode 100644 index 0000000..14c7c73 --- /dev/null +++ b/tests/instance/test_file.py @@ -0,0 +1,32 @@ +# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. +# Released under the modified BSD license. See COPYING.md for more details. + +import tempfile + +from miplearn.solvers.learning import LearningSolver +from miplearn.solvers.gurobi import GurobiSolver +from miplearn.features.sample import Hdf5Sample +from miplearn.instance.file import FileInstance + + +def test_usage() -> None: + # Create original instance + original = GurobiSolver().build_test_instance_knapsack() + + # Save instance to disk + file = tempfile.NamedTemporaryFile() + FileInstance.save(original, file.name) + sample = Hdf5Sample(file.name) + assert len(sample.get_bytes("pickled")) > 0 + + # Solve instance from disk + solver = LearningSolver(solver=GurobiSolver()) + solver.solve(FileInstance(file.name)) + + # Assert HDF5 contains training data + sample = FileInstance(file.name).get_samples()[0] + assert sample.get_scalar("mip_lower_bound") == 1183.0 + assert sample.get_scalar("mip_upper_bound") == 1183.0 + assert len(sample.get_vector("lp_var_values")) == 5 + assert len(sample.get_vector("mip_var_values")) == 5