parent
64a63264c7
commit
07388d9490
@ -1,52 +0,0 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from miplearn.components.component import Component
|
||||
|
||||
|
||||
class CompositeComponent(Component):
|
||||
"""
|
||||
A Component which redirects each method call to one or more subcomponents.
|
||||
|
||||
Useful for breaking down complex components into smaller classes. See
|
||||
RelaxationComponent for a concrete example.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
children : list[Component]
|
||||
Subcomponents that compose this component.
|
||||
"""
|
||||
|
||||
def __init__(self, children):
|
||||
self.children = children
|
||||
|
||||
def before_solve_mip(self, solver, instance, model):
|
||||
for child in self.children:
|
||||
child.before_solve_mip(solver, instance, model)
|
||||
|
||||
def after_solve_mip(
|
||||
self,
|
||||
solver,
|
||||
instance,
|
||||
model,
|
||||
stats,
|
||||
training_data,
|
||||
):
|
||||
for child in self.children:
|
||||
child.after_solve_mip(solver, instance, model, stats, training_data)
|
||||
|
||||
def fit(self, training_instances):
|
||||
for child in self.children:
|
||||
child.fit(training_instances)
|
||||
|
||||
def lazy_cb(self, solver, instance, model):
|
||||
for child in self.children:
|
||||
child.lazy_cb(solver, instance, model)
|
||||
|
||||
def iteration_cb(self, solver, instance, model):
|
||||
should_repeat = False
|
||||
for child in self.children:
|
||||
if child.iteration_cb(solver, instance, model):
|
||||
should_repeat = True
|
||||
return should_repeat
|
@ -1,57 +0,0 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from unittest.mock import Mock, call
|
||||
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.components.composite import CompositeComponent
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
|
||||
|
||||
def test_composite():
|
||||
solver, instance, model = (
|
||||
Mock(spec=LearningSolver),
|
||||
Mock(spec=Instance),
|
||||
Mock(),
|
||||
)
|
||||
|
||||
c1 = Mock(spec=Component)
|
||||
c2 = Mock(spec=Component)
|
||||
cc = CompositeComponent([c1, c2])
|
||||
|
||||
# Should broadcast before_solve
|
||||
cc.before_solve_mip(solver, instance, model)
|
||||
c1.before_solve_mip.assert_has_calls([call(solver, instance, model)])
|
||||
c2.before_solve_mip.assert_has_calls([call(solver, instance, model)])
|
||||
|
||||
# Should broadcast after_solve
|
||||
cc.after_solve_mip(solver, instance, model, {}, {})
|
||||
c1.after_solve_mip.assert_has_calls([call(solver, instance, model, {}, {})])
|
||||
c2.after_solve_mip.assert_has_calls([call(solver, instance, model, {}, {})])
|
||||
|
||||
# Should broadcast fit
|
||||
cc.fit([1, 2, 3])
|
||||
c1.fit.assert_has_calls([call([1, 2, 3])])
|
||||
c2.fit.assert_has_calls([call([1, 2, 3])])
|
||||
|
||||
# Should broadcast lazy_cb
|
||||
cc.lazy_cb(solver, instance, model)
|
||||
c1.lazy_cb.assert_has_calls([call(solver, instance, model)])
|
||||
c2.lazy_cb.assert_has_calls([call(solver, instance, model)])
|
||||
|
||||
# Should broadcast iteration_cb
|
||||
cc.iteration_cb(solver, instance, model)
|
||||
c1.iteration_cb.assert_has_calls([call(solver, instance, model)])
|
||||
c2.iteration_cb.assert_has_calls([call(solver, instance, model)])
|
||||
|
||||
# If at least one child component returns true, iteration_cb should return True
|
||||
c1.iteration_cb = Mock(return_value=True)
|
||||
c2.iteration_cb = Mock(return_value=False)
|
||||
assert cc.iteration_cb(solver, instance, model)
|
||||
|
||||
# If all children return False, iteration_cb should return False
|
||||
c1.iteration_cb = Mock(return_value=False)
|
||||
c2.iteration_cb = Mock(return_value=False)
|
||||
assert not cc.iteration_cb(solver, instance, model)
|
Loading…
Reference in new issue