You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
2.1 KiB
56 lines
2.1 KiB
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
from unittest.mock import Mock, call
|
|
|
|
from miplearn import Component, LearningSolver, Instance
|
|
from miplearn.components.composite import CompositeComponent
|
|
|
|
|
|
def test_composite():
    """Verify that CompositeComponent relays each hook to all of its children.

    Two mocked components are wrapped in a CompositeComponent. Every hook
    (before_solve, after_solve, fit, lazy_cb, iteration_cb) is invoked on the
    composite, and each child is then expected to have recorded a call with
    the same arguments. Finally, the value returned by the composite's
    iteration_cb is checked against the values returned by the children.
    """
    solver = Mock(spec=LearningSolver)
    instance = Mock(spec=Instance)
    model = Mock()

    first = Mock(spec=Component)
    second = Mock(spec=Component)
    composite = CompositeComponent([first, second])
    # Checked in this order throughout, mirroring the construction order.
    children = (first, second)

    # before_solve must reach every child.
    composite.before_solve(solver, instance, model)
    for child in children:
        child.before_solve.assert_has_calls([call(solver, instance, model)])

    # after_solve must reach every child.
    composite.after_solve(solver, instance, model, {}, {})
    for child in children:
        child.after_solve.assert_has_calls(
            [call(solver, instance, model, {}, {})]
        )

    # fit must reach every child.
    composite.fit([1, 2, 3])
    for child in children:
        child.fit.assert_has_calls([call([1, 2, 3])])

    # lazy_cb must reach every child.
    composite.lazy_cb(solver, instance, model)
    for child in children:
        child.lazy_cb.assert_has_calls([call(solver, instance, model)])

    # iteration_cb must reach every child.
    composite.iteration_cb(solver, instance, model)
    for child in children:
        child.iteration_cb.assert_has_calls([call(solver, instance, model)])

    # A single child answering True is enough for the composite to be truthy.
    first.iteration_cb = Mock(return_value=True)
    second.iteration_cb = Mock(return_value=False)
    any_true = composite.iteration_cb(solver, instance, model)
    assert any_true

    # When every child answers False, the composite must be falsy as well.
    first.iteration_cb = Mock(return_value=False)
    second.iteration_cb = Mock(return_value=False)
    all_false = composite.iteration_cb(solver, instance, model)
    assert not all_false
|