parent
64a63264c7
commit
07388d9490
@ -1,52 +0,0 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
|
||||||
|
|
||||||
from miplearn.components.component import Component
|
|
||||||
|
|
||||||
|
|
||||||
class CompositeComponent(Component):
|
|
||||||
"""
|
|
||||||
A Component which redirects each method call to one or more subcomponents.
|
|
||||||
|
|
||||||
Useful for breaking down complex components into smaller classes. See
|
|
||||||
RelaxationComponent for a concrete example.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
children : list[Component]
|
|
||||||
Subcomponents that compose this component.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, children):
|
|
||||||
self.children = children
|
|
||||||
|
|
||||||
def before_solve_mip(self, solver, instance, model):
|
|
||||||
for child in self.children:
|
|
||||||
child.before_solve_mip(solver, instance, model)
|
|
||||||
|
|
||||||
def after_solve_mip(
|
|
||||||
self,
|
|
||||||
solver,
|
|
||||||
instance,
|
|
||||||
model,
|
|
||||||
stats,
|
|
||||||
training_data,
|
|
||||||
):
|
|
||||||
for child in self.children:
|
|
||||||
child.after_solve_mip(solver, instance, model, stats, training_data)
|
|
||||||
|
|
||||||
def fit(self, training_instances):
|
|
||||||
for child in self.children:
|
|
||||||
child.fit(training_instances)
|
|
||||||
|
|
||||||
def lazy_cb(self, solver, instance, model):
|
|
||||||
for child in self.children:
|
|
||||||
child.lazy_cb(solver, instance, model)
|
|
||||||
|
|
||||||
def iteration_cb(self, solver, instance, model):
|
|
||||||
should_repeat = False
|
|
||||||
for child in self.children:
|
|
||||||
if child.iteration_cb(solver, instance, model):
|
|
||||||
should_repeat = True
|
|
||||||
return should_repeat
|
|
@ -1,57 +0,0 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
|
||||||
|
|
||||||
from unittest.mock import Mock, call
|
|
||||||
|
|
||||||
from miplearn.components.component import Component
|
|
||||||
from miplearn.components.composite import CompositeComponent
|
|
||||||
from miplearn.instance import Instance
|
|
||||||
from miplearn.solvers.learning import LearningSolver
|
|
||||||
|
|
||||||
|
|
||||||
def test_composite():
|
|
||||||
solver, instance, model = (
|
|
||||||
Mock(spec=LearningSolver),
|
|
||||||
Mock(spec=Instance),
|
|
||||||
Mock(),
|
|
||||||
)
|
|
||||||
|
|
||||||
c1 = Mock(spec=Component)
|
|
||||||
c2 = Mock(spec=Component)
|
|
||||||
cc = CompositeComponent([c1, c2])
|
|
||||||
|
|
||||||
# Should broadcast before_solve
|
|
||||||
cc.before_solve_mip(solver, instance, model)
|
|
||||||
c1.before_solve_mip.assert_has_calls([call(solver, instance, model)])
|
|
||||||
c2.before_solve_mip.assert_has_calls([call(solver, instance, model)])
|
|
||||||
|
|
||||||
# Should broadcast after_solve
|
|
||||||
cc.after_solve_mip(solver, instance, model, {}, {})
|
|
||||||
c1.after_solve_mip.assert_has_calls([call(solver, instance, model, {}, {})])
|
|
||||||
c2.after_solve_mip.assert_has_calls([call(solver, instance, model, {}, {})])
|
|
||||||
|
|
||||||
# Should broadcast fit
|
|
||||||
cc.fit([1, 2, 3])
|
|
||||||
c1.fit.assert_has_calls([call([1, 2, 3])])
|
|
||||||
c2.fit.assert_has_calls([call([1, 2, 3])])
|
|
||||||
|
|
||||||
# Should broadcast lazy_cb
|
|
||||||
cc.lazy_cb(solver, instance, model)
|
|
||||||
c1.lazy_cb.assert_has_calls([call(solver, instance, model)])
|
|
||||||
c2.lazy_cb.assert_has_calls([call(solver, instance, model)])
|
|
||||||
|
|
||||||
# Should broadcast iteration_cb
|
|
||||||
cc.iteration_cb(solver, instance, model)
|
|
||||||
c1.iteration_cb.assert_has_calls([call(solver, instance, model)])
|
|
||||||
c2.iteration_cb.assert_has_calls([call(solver, instance, model)])
|
|
||||||
|
|
||||||
# If at least one child component returns true, iteration_cb should return True
|
|
||||||
c1.iteration_cb = Mock(return_value=True)
|
|
||||||
c2.iteration_cb = Mock(return_value=False)
|
|
||||||
assert cc.iteration_cb(solver, instance, model)
|
|
||||||
|
|
||||||
# If all children return False, iteration_cb should return False
|
|
||||||
c1.iteration_cb = Mock(return_value=False)
|
|
||||||
c2.iteration_cb = Mock(return_value=False)
|
|
||||||
assert not cc.iteration_cb(solver, instance, model)
|
|
Loading…
Reference in new issue