mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Capture solver log
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from io import StringIO
|
||||
|
||||
import pyomo.core.kernel.objective
|
||||
import pyomo.environ as pe
|
||||
@@ -40,6 +42,33 @@ def _parallel_solve(instance_idx):
|
||||
}
|
||||
|
||||
|
||||
class RedirectOutput(object):
|
||||
def __init__(self, streams):
|
||||
self.streams = streams
|
||||
self._original_stdout = sys.stdout
|
||||
self._original_stderr = sys.stderr
|
||||
sys.stdout = self
|
||||
sys.stderr = self
|
||||
|
||||
def __del__(self):
|
||||
sys.stdout = self._original_stdout
|
||||
sys.stderr = self._original_stderr
|
||||
|
||||
def write(self, data):
|
||||
for stream in self.streams:
|
||||
stream.write(data)
|
||||
|
||||
def flush(self):
|
||||
for stream in self.streams:
|
||||
stream.flush()
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, _type, _value, _traceback):
|
||||
pass
|
||||
|
||||
|
||||
class InternalSolver(ABC):
|
||||
"""
|
||||
The MIP solver used internaly by LearningSolver.
|
||||
@@ -217,13 +246,18 @@ class InternalSolver(ABC):
|
||||
-------
|
||||
dict
|
||||
A dictionary of solver statistics containing the following keys:
|
||||
"Lower bound", "Upper bound", "Wallclock time", "Nodes" and "Sense".
|
||||
"Lower bound", "Upper bound", "Wallclock time", "Nodes", "Sense"
|
||||
and "Log".
|
||||
"""
|
||||
total_wallclock_time = 0
|
||||
streams = [StringIO()]
|
||||
if tee:
|
||||
streams += [sys.stdout]
|
||||
self.instance.found_violations = []
|
||||
while True:
|
||||
logger.debug("Solving MIP...")
|
||||
results = self._pyomo_solver.solve(tee=tee)
|
||||
with RedirectOutput(streams):
|
||||
results = self._pyomo_solver.solve(tee=True)
|
||||
total_wallclock_time += results["Solver"][0]["Wallclock time"]
|
||||
if not hasattr(self.instance, "find_violations"):
|
||||
break
|
||||
@@ -243,6 +277,7 @@ class InternalSolver(ABC):
|
||||
"Wallclock time": total_wallclock_time,
|
||||
"Nodes": 1,
|
||||
"Sense": self._obj_sense,
|
||||
"Log": streams[0].getvalue()
|
||||
}
|
||||
|
||||
|
||||
@@ -280,8 +315,13 @@ class GurobiSolver(InternalSolver):
|
||||
self._pyomo_solver.set_callback(cb)
|
||||
self.instance.found_violations = []
|
||||
print(self._is_warm_start_available)
|
||||
results = self._pyomo_solver.solve(tee=tee,
|
||||
warmstart=self._is_warm_start_available)
|
||||
|
||||
streams = [StringIO()]
|
||||
if tee:
|
||||
streams += [sys.stdout]
|
||||
with RedirectOutput(streams):
|
||||
results = self._pyomo_solver.solve(tee=True,
|
||||
warmstart=self._is_warm_start_available)
|
||||
self._pyomo_solver.set_callback(None)
|
||||
node_count = int(self._pyomo_solver._solver_model.getAttr("NodeCount"))
|
||||
return {
|
||||
@@ -290,6 +330,7 @@ class GurobiSolver(InternalSolver):
|
||||
"Wallclock time": results["Solver"][0]["Wallclock time"],
|
||||
"Nodes": max(1, node_count),
|
||||
"Sense": self._obj_sense,
|
||||
"Log": streams[0].getvalue(),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,8 @@ def test_internal_solver():
|
||||
}
|
||||
})
|
||||
|
||||
stats = solver.solve()
|
||||
stats = solver.solve(tee=True)
|
||||
assert len(stats["Log"]) > 100
|
||||
assert stats["Lower bound"] == 1183.0
|
||||
assert stats["Upper bound"] == 1183.0
|
||||
assert stats["Sense"] == "max"
|
||||
|
||||
Reference in New Issue
Block a user