mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Add user cut callbacks; begin rewrite of UserCutsComponent
This commit is contained in:
@@ -17,7 +17,7 @@ from miplearn.solvers.internal import (
|
||||
LazyCallback,
|
||||
MIPSolveStats,
|
||||
)
|
||||
from miplearn.types import VarIndex, SolverParams, Solution
|
||||
from miplearn.types import VarIndex, SolverParams, Solution, UserCutCallback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -153,41 +153,49 @@ class GurobiSolver(InternalSolver):
|
||||
tee: bool = False,
|
||||
iteration_cb: IterationCallback = None,
|
||||
lazy_cb: LazyCallback = None,
|
||||
user_cut_cb: UserCutCallback = None,
|
||||
) -> MIPSolveStats:
|
||||
self._raise_if_callback()
|
||||
assert self.model is not None
|
||||
if iteration_cb is None:
|
||||
iteration_cb = lambda: False
|
||||
|
||||
# Create callback wrapper
|
||||
def cb_wrapper(cb_model, cb_where):
|
||||
try:
|
||||
self.cb_where = cb_where
|
||||
if cb_where in self.lazy_cb_where:
|
||||
if lazy_cb is not None and cb_where in self.lazy_cb_where:
|
||||
lazy_cb(self, self.model)
|
||||
if user_cut_cb is not None and cb_where == self.gp.GRB.Callback.MIPNODE:
|
||||
user_cut_cb(self, self.model)
|
||||
except:
|
||||
logger.exception("callback error")
|
||||
finally:
|
||||
self.cb_where = None
|
||||
|
||||
if lazy_cb:
|
||||
# Configure Gurobi
|
||||
if lazy_cb is not None:
|
||||
self.params["LazyConstraints"] = 1
|
||||
if user_cut_cb is not None:
|
||||
self.params["PreCrush"] = 1
|
||||
|
||||
# Solve problem
|
||||
total_wallclock_time = 0
|
||||
total_nodes = 0
|
||||
streams: List[Any] = [StringIO()]
|
||||
if tee:
|
||||
streams += [sys.stdout]
|
||||
self._apply_params(streams)
|
||||
if iteration_cb is None:
|
||||
iteration_cb = lambda: False
|
||||
while True:
|
||||
with _RedirectOutput(streams):
|
||||
if lazy_cb is None:
|
||||
self.model.optimize()
|
||||
else:
|
||||
self.model.optimize(cb_wrapper)
|
||||
self.model.optimize(cb_wrapper)
|
||||
total_wallclock_time += self.model.runtime
|
||||
total_nodes += int(self.model.nodeCount)
|
||||
should_repeat = iteration_cb()
|
||||
if not should_repeat:
|
||||
break
|
||||
|
||||
# Fetch results and stats
|
||||
log = streams[0].getvalue()
|
||||
ub, lb = None, None
|
||||
sense = "min" if self.model.modelSense == 1 else "max"
|
||||
@@ -313,6 +321,11 @@ class GurobiSolver(InternalSolver):
|
||||
else:
|
||||
self.model.addConstr(constraint, name=name)
|
||||
|
||||
def add_cut(self, cobj: Any) -> None:
|
||||
assert self.model is not None
|
||||
assert self.cb_where == self.gp.GRB.Callback.MIPNODE
|
||||
self.model.cbCut(cobj)
|
||||
|
||||
def _clear_warm_start(self) -> None:
|
||||
for (varname, vardict) in self._all_vars.items():
|
||||
for (idx, var) in vardict.items():
|
||||
@@ -421,7 +434,6 @@ class GurobiSolver(InternalSolver):
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
||||
self.params = state["params"]
|
||||
self.lazy_cb_where = state["lazy_cb_where"]
|
||||
self.instance = None
|
||||
|
||||
Reference in New Issue
Block a user