diff --git a/miplearn/components/composite.py b/miplearn/components/composite.py new file mode 100644 index 0000000..b349643 --- /dev/null +++ b/miplearn/components/composite.py @@ -0,0 +1,44 @@ +# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. +# Released under the modified BSD license. See COPYING.md for more details. + +from miplearn import Component + + +class CompositeComponent(Component): + """ + A Component which redirects each method call to one or more subcomponents. Useful + for breaking down complex components into smaller classes. See RelaxationComponent + for a concrete example. + + Parameters + ---------- + children : list[Component] + Subcomponents that compose this component. + """ + + def __init__(self, children): + self.children = children + + def before_solve(self, solver, instance, model): + for child in self.children: + child.before_solve(solver, instance, model) + + def after_solve(self, solver, instance, model, results): + for child in self.children: + child.after_solve(solver, instance, model, results) + + def fit(self, training_instances): + for child in self.children: + child.fit(training_instances) + + def lazy_cb(self, solver, instance, model): + for child in self.children: + child.lazy_cb(solver, instance, model) + + def iteration_cb(self, solver, instance, model): + should_repeat = False + for child in self.children: + if child.iteration_cb(solver, instance, model): + should_repeat = True + return should_repeat diff --git a/miplearn/components/tests/test_composite.py b/miplearn/components/tests/test_composite.py new file mode 100644 index 0000000..8e8a24b --- /dev/null +++ b/miplearn/components/tests/test_composite.py @@ -0,0 +1,57 @@ +# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. +# Released under the modified BSD license. See COPYING.md for more details. + +from unittest.mock import Mock, call + +from nltk import Model + +from miplearn import Component, LearningSolver, Instance +from miplearn.components.composite import CompositeComponent + + +def test_composite(): + solver, instance, model = ( + Mock(spec=LearningSolver), + Mock(spec=Instance), + Mock(spec=Model), + ) + + c1 = Mock(spec=Component) + c2 = Mock(spec=Component) + cc = CompositeComponent([c1, c2]) + + # Should broadcast before_solve + cc.before_solve(solver, instance, model) + c1.before_solve.assert_has_calls([call(solver, instance, model)]) + c2.before_solve.assert_has_calls([call(solver, instance, model)]) + + # Should broadcast after_solve + cc.after_solve(solver, instance, model, {}) + c1.after_solve.assert_has_calls([call(solver, instance, model, {})]) + c2.after_solve.assert_has_calls([call(solver, instance, model, {})]) + + # Should broadcast fit + cc.fit([1, 2, 3]) + c1.fit.assert_has_calls([call([1, 2, 3])]) + c2.fit.assert_has_calls([call([1, 2, 3])]) + + # Should broadcast lazy_cb + cc.lazy_cb(solver, instance, model) + c1.lazy_cb.assert_has_calls([call(solver, instance, model)]) + c2.lazy_cb.assert_has_calls([call(solver, instance, model)]) + + # Should broadcast iteration_cb + cc.iteration_cb(solver, instance, model) + c1.iteration_cb.assert_has_calls([call(solver, instance, model)]) + c2.iteration_cb.assert_has_calls([call(solver, instance, model)]) + + # If at least one child component returns true, iteration_cb should return True + c1.iteration_cb = Mock(return_value=True) + c2.iteration_cb = Mock(return_value=False) + assert cc.iteration_cb(solver, instance, model) + + # If all children return False, iteration_cb should return False + c1.iteration_cb = Mock(return_value=False) + c2.iteration_cb = Mock(return_value=False) + assert not cc.iteration_cb(solver, instance, model)