mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Fix RedirectOutput; add tests
This commit is contained in:
@@ -2,20 +2,15 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RedirectOutput(object):
|
||||
|
||||
class RedirectOutput:
|
||||
def __init__(self, streams):
|
||||
self.streams = streams
|
||||
self._original_stdout = sys.stdout
|
||||
self._original_stderr = sys.stderr
|
||||
sys.stdout = self
|
||||
sys.stderr = self
|
||||
|
||||
def __del__(self):
|
||||
sys.stdout = self._original_stdout
|
||||
sys.stderr = self._original_stderr
|
||||
|
||||
def write(self, data):
|
||||
for stream in self.streams:
|
||||
@@ -26,7 +21,12 @@ class RedirectOutput(object):
|
||||
stream.flush()
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
self._original_stdout = sys.stdout
|
||||
self._original_stderr = sys.stderr
|
||||
sys.stdout = self
|
||||
sys.stderr = self
|
||||
return self
|
||||
|
||||
def __exit__(self, _type, _value, _traceback):
|
||||
pass
|
||||
sys.stdout = self._original_stdout
|
||||
sys.stderr = self._original_stderr
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from io import StringIO
|
||||
|
||||
import pyomo.environ as pe
|
||||
from miplearn.solvers import RedirectOutput
|
||||
from miplearn.solvers.cplex import CPLEXSolver
|
||||
from miplearn.solvers.gurobi import GurobiSolver
|
||||
|
||||
@@ -10,6 +12,16 @@ from . import _get_instance
|
||||
from ...problems.knapsack import ChallengeA
|
||||
|
||||
|
||||
def test_redirect_output():
|
||||
import sys
|
||||
original_stdout = sys.stdout
|
||||
io = StringIO()
|
||||
with RedirectOutput([io]):
|
||||
print("Hello world")
|
||||
assert sys.stdout == original_stdout
|
||||
assert io.getvalue() == "Hello world\n"
|
||||
|
||||
|
||||
def test_internal_solver_warm_starts():
|
||||
for solver in [GurobiSolver(), CPLEXSolver()]:
|
||||
instance = _get_instance()
|
||||
|
||||
Reference in New Issue
Block a user