mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Fix RedirectOutput; add tests
This commit is contained in:
@@ -2,20 +2,15 @@
|
|||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class RedirectOutput(object):
|
|
||||||
|
class RedirectOutput:
|
||||||
def __init__(self, streams):
|
def __init__(self, streams):
|
||||||
self.streams = streams
|
self.streams = streams
|
||||||
self._original_stdout = sys.stdout
|
|
||||||
self._original_stderr = sys.stderr
|
|
||||||
sys.stdout = self
|
|
||||||
sys.stderr = self
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
sys.stdout = self._original_stdout
|
|
||||||
sys.stderr = self._original_stderr
|
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
for stream in self.streams:
|
for stream in self.streams:
|
||||||
@@ -26,7 +21,12 @@ class RedirectOutput(object):
|
|||||||
stream.flush()
|
stream.flush()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
pass
|
self._original_stdout = sys.stdout
|
||||||
|
self._original_stderr = sys.stderr
|
||||||
|
sys.stdout = self
|
||||||
|
sys.stderr = self
|
||||||
|
return self
|
||||||
|
|
||||||
def __exit__(self, _type, _value, _traceback):
|
def __exit__(self, _type, _value, _traceback):
|
||||||
pass
|
sys.stdout = self._original_stdout
|
||||||
|
sys.stderr = self._original_stderr
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
import pyomo.environ as pe
|
import pyomo.environ as pe
|
||||||
|
from miplearn.solvers import RedirectOutput
|
||||||
from miplearn.solvers.cplex import CPLEXSolver
|
from miplearn.solvers.cplex import CPLEXSolver
|
||||||
from miplearn.solvers.gurobi import GurobiSolver
|
from miplearn.solvers.gurobi import GurobiSolver
|
||||||
|
|
||||||
@@ -10,6 +12,16 @@ from . import _get_instance
|
|||||||
from ...problems.knapsack import ChallengeA
|
from ...problems.knapsack import ChallengeA
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_output():
|
||||||
|
import sys
|
||||||
|
original_stdout = sys.stdout
|
||||||
|
io = StringIO()
|
||||||
|
with RedirectOutput([io]):
|
||||||
|
print("Hello world")
|
||||||
|
assert sys.stdout == original_stdout
|
||||||
|
assert io.getvalue() == "Hello world\n"
|
||||||
|
|
||||||
|
|
||||||
def test_internal_solver_warm_starts():
|
def test_internal_solver_warm_starts():
|
||||||
for solver in [GurobiSolver(), CPLEXSolver()]:
|
for solver in [GurobiSolver(), CPLEXSolver()]:
|
||||||
instance = _get_instance()
|
instance = _get_instance()
|
||||||
|
|||||||
Reference in New Issue
Block a user