# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.

from unittest.mock import Mock, call

from miplearn.components.component import Component
from miplearn.components.composite import CompositeComponent
from miplearn.instance import Instance
from miplearn.solvers.learning import LearningSolver


def test_composite():
    """Check that CompositeComponent forwards each hook to all of its children.

    Two mocked child components are wrapped in a composite. Every hook invoked
    on the composite must reach both children with the same arguments, and
    iteration_cb must be truthy exactly when at least one child returns True.
    """
    solver = Mock(spec=LearningSolver)
    instance = Mock(spec=Instance)
    model = Mock()
    first = Mock(spec=Component)
    second = Mock(spec=Component)
    children = (first, second)
    composite = CompositeComponent([first, second])

    # Should broadcast before_solve
    composite.before_solve(solver, instance, model)
    for child in children:
        child.before_solve.assert_has_calls([call(solver, instance, model)])

    # Should broadcast after_solve
    composite.after_solve(solver, instance, model, {}, {})
    for child in children:
        child.after_solve.assert_has_calls(
            [call(solver, instance, model, {}, {})]
        )

    # Should broadcast fit
    composite.fit([1, 2, 3])
    for child in children:
        child.fit.assert_has_calls([call([1, 2, 3])])

    # Should broadcast lazy_cb
    composite.lazy_cb(solver, instance, model)
    for child in children:
        child.lazy_cb.assert_has_calls([call(solver, instance, model)])

    # Should broadcast iteration_cb
    composite.iteration_cb(solver, instance, model)
    for child in children:
        child.iteration_cb.assert_has_calls([call(solver, instance, model)])

    # Each row: (first child's answer, second child's answer, expected result).
    # - If at least one child component returns True, iteration_cb should
    #   return True.
    # - If all children return False, iteration_cb should return False.
    scenarios = (
        (True, False, True),
        (False, False, False),
    )
    for first_answer, second_answer, expected in scenarios:
        first.iteration_cb = Mock(return_value=first_answer)
        second.iteration_cb = Mock(return_value=second_answer)
        outcome = composite.iteration_cb(solver, instance, model)
        # Only truthiness is checked, matching a plain `assert` / `assert not`.
        assert bool(outcome) is expected