Implement CompositeComponent

pull/3/head
Alinson S. Xavier 5 years ago
parent 95672ad529
commit 94b493ac4b

@ -0,0 +1,44 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from miplearn import Component
class CompositeComponent(Component):
"""
A Component which redirects each method call to one or more subcomponents. Useful
for breaking down complex components into smaller classes. See RelaxationComponent
for a concrete example.
Parameters
----------
children : list[Component]
Subcomponents that compose this component.
"""
def __init__(self, children):
self.children = children
def before_solve(self, solver, instance, model):
for child in self.children:
child.before_solve(solver, instance, model)
def after_solve(self, solver, instance, model, results):
for child in self.children:
child.after_solve(solver, instance, model, results)
def fit(self, training_instances):
for child in self.children:
child.fit(training_instances)
def lazy_cb(self, solver, instance, model):
for child in self.children:
child.lazy_cb(solver, instance, model)
def iteration_cb(self, solver, instance, model):
should_repeat = False
for child in self.children:
if child.iteration_cb(solver, instance, model):
should_repeat = True
return should_repeat

@ -0,0 +1,57 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from unittest.mock import Mock, call
from nltk import Model
from miplearn import Component, LearningSolver, Instance
from miplearn.components.composite import CompositeComponent
def test_composite():
solver, instance, model = (
Mock(spec=LearningSolver),
Mock(spec=Instance),
Mock(spec=Model),
)
c1 = Mock(spec=Component)
c2 = Mock(spec=Component)
cc = CompositeComponent([c1, c2])
# Should broadcast before_solve
cc.before_solve(solver, instance, model)
c1.before_solve.assert_has_calls([call(solver, instance, model)])
c2.before_solve.assert_has_calls([call(solver, instance, model)])
# Should broadcast after_solve
cc.after_solve(solver, instance, model, {})
c1.after_solve.assert_has_calls([call(solver, instance, model, {})])
c2.after_solve.assert_has_calls([call(solver, instance, model, {})])
# Should broadcast fit
cc.fit([1, 2, 3])
c1.fit.assert_has_calls([call([1, 2, 3])])
c2.fit.assert_has_calls([call([1, 2, 3])])
# Should broadcast lazy_cb
cc.lazy_cb(solver, instance, model)
c1.lazy_cb.assert_has_calls([call(solver, instance, model)])
c2.lazy_cb.assert_has_calls([call(solver, instance, model)])
# Should broadcast iteration_cb
cc.iteration_cb(solver, instance, model)
c1.iteration_cb.assert_has_calls([call(solver, instance, model)])
c2.iteration_cb.assert_has_calls([call(solver, instance, model)])
# If at least one child component returns true, iteration_cb should return True
c1.iteration_cb = Mock(return_value=True)
c2.iteration_cb = Mock(return_value=False)
assert cc.iteration_cb(solver, instance, model)
# If all children return False, iteration_cb should return False
c1.iteration_cb = Mock(return_value=False)
c2.iteration_cb = Mock(return_value=False)
assert not cc.iteration_cb(solver, instance, model)
Loading…
Cancel
Save