diff --git a/src/python/miplearn/solvers.py b/src/python/miplearn/solvers.py index 090d855..db02560 100644 --- a/src/python/miplearn/solvers.py +++ b/src/python/miplearn/solvers.py @@ -3,8 +3,10 @@ # Released under the modified BSD license. See COPYING.md for more details. import logging +import sys from abc import ABC from copy import deepcopy +from io import StringIO import pyomo.core.kernel.objective import pyomo.environ as pe @@ -40,6 +42,33 @@ def _parallel_solve(instance_idx): } +class RedirectOutput(object): + def __init__(self, streams): + self.streams = streams + self._original_stdout = sys.stdout + self._original_stderr = sys.stderr + sys.stdout = self + sys.stderr = self + + def __del__(self): + sys.stdout = self._original_stdout + sys.stderr = self._original_stderr + + def write(self, data): + for stream in self.streams: + stream.write(data) + + def flush(self): + for stream in self.streams: + stream.flush() + + def __enter__(self): + pass + + def __exit__(self, _type, _value, _traceback): + pass + + class InternalSolver(ABC): """ The MIP solver used internaly by LearningSolver. @@ -217,13 +246,18 @@ class InternalSolver(ABC): ------- dict A dictionary of solver statistics containing the following keys: - "Lower bound", "Upper bound", "Wallclock time", "Nodes" and "Sense". + "Lower bound", "Upper bound", "Wallclock time", "Nodes", "Sense" + and "Log". """ total_wallclock_time = 0 + streams = [StringIO()] + if tee: + streams += [sys.stdout] self.instance.found_violations = [] while True: logger.debug("Solving MIP...") - results = self._pyomo_solver.solve(tee=tee) + with RedirectOutput(streams): + results = self._pyomo_solver.solve(tee=True) total_wallclock_time += results["Solver"][0]["Wallclock time"] if not hasattr(self.instance, "find_violations"): break @@ -243,6 +277,7 @@ class InternalSolver(ABC): "Wallclock time": total_wallclock_time, "Nodes": 1, "Sense": self._obj_sense, + "Log": streams[0].getvalue() } @@ -280,8 +315,13 @@ class GurobiSolver(InternalSolver): self._pyomo_solver.set_callback(cb) self.instance.found_violations = [] print(self._is_warm_start_available) - results = self._pyomo_solver.solve(tee=tee, - warmstart=self._is_warm_start_available) + + streams = [StringIO()] + if tee: + streams += [sys.stdout] + with RedirectOutput(streams): + results = self._pyomo_solver.solve(tee=True, + warmstart=self._is_warm_start_available) self._pyomo_solver.set_callback(None) node_count = int(self._pyomo_solver._solver_model.getAttr("NodeCount")) return { @@ -290,6 +330,7 @@ class GurobiSolver(InternalSolver): "Wallclock time": results["Solver"][0]["Wallclock time"], "Nodes": max(1, node_count), "Sense": self._obj_sense, + "Log": streams[0].getvalue(), } diff --git a/src/python/miplearn/tests/test_solver.py b/src/python/miplearn/tests/test_solver.py index 5da9104..1b61971 100644 --- a/src/python/miplearn/tests/test_solver.py +++ b/src/python/miplearn/tests/test_solver.py @@ -34,7 +34,8 @@ def test_internal_solver(): } }) - stats = solver.solve() + stats = solver.solve(tee=True) + assert len(stats["Log"]) > 100 assert stats["Lower bound"] == 1183.0 assert stats["Upper bound"] == 1183.0 assert stats["Sense"] == "max"