mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Reorganize callbacks
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from io import StringIO
|
||||
from random import randint
|
||||
from typing import List, Any, Dict, Optional, Hashable
|
||||
@@ -31,6 +32,14 @@ from miplearn.types import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedGurobiConstraint:
|
||||
lhs: Any
|
||||
rhs: float
|
||||
sense: str
|
||||
name: str
|
||||
|
||||
|
||||
class GurobiSolver(InternalSolver):
|
||||
"""
|
||||
An InternalSolver backed by Gurobi's Python API (without Pyomo).
|
||||
@@ -158,6 +167,7 @@ class GurobiSolver(InternalSolver):
|
||||
assert self.model is not None
|
||||
if iteration_cb is None:
|
||||
iteration_cb = lambda: False
|
||||
callback_exceptions = []
|
||||
|
||||
# Create callback wrapper
|
||||
def cb_wrapper(cb_model: Any, cb_where: int) -> None:
|
||||
@@ -167,8 +177,9 @@ class GurobiSolver(InternalSolver):
|
||||
lazy_cb(self, self.model)
|
||||
if user_cut_cb is not None and cb_where == self.gp.GRB.Callback.MIPNODE:
|
||||
user_cut_cb(self, self.model)
|
||||
except:
|
||||
except Exception as e:
|
||||
logger.exception("callback error")
|
||||
callback_exceptions.append(e)
|
||||
finally:
|
||||
self.cb_where = None
|
||||
|
||||
@@ -188,6 +199,8 @@ class GurobiSolver(InternalSolver):
|
||||
while True:
|
||||
with _RedirectOutput(streams):
|
||||
self.model.optimize(cb_wrapper)
|
||||
if len(callback_exceptions) > 0:
|
||||
raise callback_exceptions[0]
|
||||
total_wallclock_time += self.model.runtime
|
||||
total_nodes += int(self.model.nodeCount)
|
||||
should_repeat = iteration_cb()
|
||||
@@ -279,29 +292,26 @@ class GurobiSolver(InternalSolver):
|
||||
)
|
||||
|
||||
@overrides
|
||||
def add_constraint(
|
||||
self,
|
||||
constraint: Any,
|
||||
name: str = "",
|
||||
) -> None:
|
||||
def add_constraint(self, cobj: Any, name: str = "") -> None:
|
||||
assert self.model is not None
|
||||
if type(constraint) is tuple:
|
||||
lhs, sense, rhs, name = constraint
|
||||
if isinstance(cobj, ExtractedGurobiConstraint):
|
||||
if self.cb_where in [
|
||||
self.gp.GRB.Callback.MIPSOL,
|
||||
self.gp.GRB.Callback.MIPNODE,
|
||||
]:
|
||||
self.model.cbLazy(lhs, sense, rhs)
|
||||
self.model.cbLazy(cobj.lhs, cobj.sense, cobj.rhs)
|
||||
else:
|
||||
self.model.addConstr(lhs, sense, rhs, name)
|
||||
self.model.addConstr(cobj.lhs, cobj.sense, cobj.rhs, cobj.name)
|
||||
elif isinstance(cobj, self.gp.TempConstr):
|
||||
if self.cb_where in [
|
||||
self.gp.GRB.Callback.MIPSOL,
|
||||
self.gp.GRB.Callback.MIPNODE,
|
||||
]:
|
||||
self.model.cbLazy(cobj)
|
||||
else:
|
||||
self.model.addConstr(cobj, name=name)
|
||||
else:
|
||||
if self.cb_where in [
|
||||
self.gp.GRB.Callback.MIPSOL,
|
||||
self.gp.GRB.Callback.MIPNODE,
|
||||
]:
|
||||
self.model.cbLazy(constraint)
|
||||
else:
|
||||
self.model.addConstr(constraint, name=name)
|
||||
raise Exception(f"unknown constraint type: {cobj.__class__.__name__}")
|
||||
|
||||
@overrides
|
||||
def add_cut(self, cobj: Any) -> None:
|
||||
@@ -325,21 +335,27 @@ class GurobiSolver(InternalSolver):
|
||||
var.ub = value
|
||||
|
||||
@overrides
|
||||
def extract_constraint(self, cid: str) -> Any:
|
||||
def extract_constraint(self, cid: str) -> ExtractedGurobiConstraint:
|
||||
self._raise_if_callback()
|
||||
assert self.model is not None
|
||||
constr = self.model.getConstrByName(cid)
|
||||
cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName)
|
||||
cobj = ExtractedGurobiConstraint(
|
||||
lhs=self.model.getRow(constr),
|
||||
sense=constr.sense,
|
||||
rhs=constr.RHS,
|
||||
name=constr.ConstrName,
|
||||
)
|
||||
self.model.remove(constr)
|
||||
return cobj
|
||||
|
||||
@overrides
|
||||
def is_constraint_satisfied(
|
||||
self,
|
||||
cobj: Any,
|
||||
cobj: ExtractedGurobiConstraint,
|
||||
tol: float = 1e-6,
|
||||
) -> bool:
|
||||
lhs, sense, rhs, name = cobj
|
||||
assert isinstance(cobj, ExtractedGurobiConstraint)
|
||||
lhs, sense, rhs, _ = cobj.lhs, cobj.sense, cobj.rhs, cobj.name
|
||||
if self.cb_where is not None:
|
||||
lhs_value = lhs.getConstant()
|
||||
for i in range(lhs.size()):
|
||||
@@ -433,19 +449,23 @@ class GurobiSolver(InternalSolver):
|
||||
self.model.update()
|
||||
constraints: Dict[str, Constraint] = {}
|
||||
for c in self.model.getConstrs():
|
||||
expr = self.model.getRow(c)
|
||||
lhs: Dict[str, float] = {}
|
||||
for i in range(expr.size()):
|
||||
lhs[expr.getVar(i).varName] = expr.getCoeff(i)
|
||||
constr = self._parse_gurobi_constraint(c)
|
||||
assert c.constrName not in constraints
|
||||
constraints[c.constrName] = Constraint(
|
||||
rhs=c.rhs,
|
||||
lhs=lhs,
|
||||
sense=c.sense,
|
||||
)
|
||||
|
||||
constraints[c.constrName] = constr
|
||||
return constraints
|
||||
|
||||
def _parse_gurobi_constraint(self, c: Any) -> Constraint:
|
||||
assert self.model is not None
|
||||
expr = self.model.getRow(c)
|
||||
lhs: Dict[str, float] = {}
|
||||
for i in range(expr.size()):
|
||||
lhs[expr.getVar(i).varName] = expr.getCoeff(i)
|
||||
return Constraint(rhs=c.rhs, lhs=lhs, sense=c.sense)
|
||||
|
||||
@overrides
|
||||
def are_callbacks_supported(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class GurobiTestInstanceInfeasible(Instance):
|
||||
@overrides
|
||||
@@ -506,5 +526,10 @@ class GurobiTestInstanceKnapsack(PyomoTestInstanceKnapsack):
|
||||
|
||||
@overrides
|
||||
def build_lazy_constraint(self, model: Any, violation: Hashable) -> Any:
|
||||
x = model.getVarByName("x[0]")
|
||||
return x <= 0.0
|
||||
# TODO: Replace by plain constraint
|
||||
return ExtractedGurobiConstraint(
|
||||
lhs=1.0 * model.getVarByName("x[0]"),
|
||||
sense="<",
|
||||
rhs=0.0,
|
||||
name="cut",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user