diff --git a/miplearn/components/cuts/expert.py b/miplearn/components/cuts/expert.py new file mode 100644 index 0000000..b4c12c6 --- /dev/null +++ b/miplearn/components/cuts/expert.py @@ -0,0 +1,35 @@ +# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved. +# Released under the modified BSD license. See COPYING.md for more details. + +import json +import logging +from typing import Dict, Any, List + +from miplearn.components.cuts.mem import convert_lists_to_tuples +from miplearn.h5 import H5File +from miplearn.solvers.abstract import AbstractModel + +logger = logging.getLogger(__name__) + + +class ExpertCutsComponent: + def fit( + self, + _: List[str], + ) -> None: + pass + + def before_mip( + self, + test_h5: str, + model: AbstractModel, + stats: Dict[str, Any], + ) -> None: + with H5File(test_h5, "r") as h5: + cuts_str = h5.get_scalar("mip_cuts") + assert cuts_str is not None + assert isinstance(cuts_str, str) + cuts = list(set(convert_lists_to_tuples(json.loads(cuts_str)))) + model.set_cuts(cuts) + stats["Cuts: AOT"] = len(cuts) diff --git a/miplearn/components/lazy/expert.py b/miplearn/components/lazy/expert.py new file mode 100644 index 0000000..3d6e9fe --- /dev/null +++ b/miplearn/components/lazy/expert.py @@ -0,0 +1,36 @@ +# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved. +# Released under the modified BSD license. See COPYING.md for more details. + +import json +import logging +from typing import Dict, Any, List + +from miplearn.components.cuts.mem import convert_lists_to_tuples +from miplearn.h5 import H5File +from miplearn.solvers.abstract import AbstractModel + +logger = logging.getLogger(__name__) + + +class ExpertLazyComponent: + def fit( + self, + _: List[str], + ) -> None: + pass + + def before_mip( + self, + test_h5: str, + model: AbstractModel, + stats: Dict[str, Any], + ) -> None: + with H5File(test_h5, "r") as h5: + violations_str = h5.get_scalar("mip_lazy") + assert violations_str is not None + assert isinstance(violations_str, str) + violations = list(set(convert_lists_to_tuples(json.loads(violations_str)))) + logger.info(f"Enforcing {len(violations)} constraints ahead-of-time...") + model.lazy_enforce(violations) + stats["Lazy Constraints: AOT"] = len(violations)