Reorganize callbacks

master
Alinson S. Xavier 5 years ago
parent 6ac738beb4
commit 735884151d
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -4,6 +4,7 @@
import logging import logging
import re import re
import sys import sys
from dataclasses import dataclass
from io import StringIO from io import StringIO
from random import randint from random import randint
from typing import List, Any, Dict, Optional, Hashable from typing import List, Any, Dict, Optional, Hashable
@ -31,6 +32,14 @@ from miplearn.types import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass
class ExtractedGurobiConstraint:
lhs: Any
rhs: float
sense: str
name: str
class GurobiSolver(InternalSolver): class GurobiSolver(InternalSolver):
""" """
An InternalSolver backed by Gurobi's Python API (without Pyomo). An InternalSolver backed by Gurobi's Python API (without Pyomo).
@ -158,6 +167,7 @@ class GurobiSolver(InternalSolver):
assert self.model is not None assert self.model is not None
if iteration_cb is None: if iteration_cb is None:
iteration_cb = lambda: False iteration_cb = lambda: False
callback_exceptions = []
# Create callback wrapper # Create callback wrapper
def cb_wrapper(cb_model: Any, cb_where: int) -> None: def cb_wrapper(cb_model: Any, cb_where: int) -> None:
@ -167,8 +177,9 @@ class GurobiSolver(InternalSolver):
lazy_cb(self, self.model) lazy_cb(self, self.model)
if user_cut_cb is not None and cb_where == self.gp.GRB.Callback.MIPNODE: if user_cut_cb is not None and cb_where == self.gp.GRB.Callback.MIPNODE:
user_cut_cb(self, self.model) user_cut_cb(self, self.model)
except: except Exception as e:
logger.exception("callback error") logger.exception("callback error")
callback_exceptions.append(e)
finally: finally:
self.cb_where = None self.cb_where = None
@ -188,6 +199,8 @@ class GurobiSolver(InternalSolver):
while True: while True:
with _RedirectOutput(streams): with _RedirectOutput(streams):
self.model.optimize(cb_wrapper) self.model.optimize(cb_wrapper)
if len(callback_exceptions) > 0:
raise callback_exceptions[0]
total_wallclock_time += self.model.runtime total_wallclock_time += self.model.runtime
total_nodes += int(self.model.nodeCount) total_nodes += int(self.model.nodeCount)
should_repeat = iteration_cb() should_repeat = iteration_cb()
@ -279,29 +292,26 @@ class GurobiSolver(InternalSolver):
) )
@overrides @overrides
def add_constraint( def add_constraint(self, cobj: Any, name: str = "") -> None:
self,
constraint: Any,
name: str = "",
) -> None:
assert self.model is not None assert self.model is not None
if type(constraint) is tuple: if isinstance(cobj, ExtractedGurobiConstraint):
lhs, sense, rhs, name = constraint
if self.cb_where in [ if self.cb_where in [
self.gp.GRB.Callback.MIPSOL, self.gp.GRB.Callback.MIPSOL,
self.gp.GRB.Callback.MIPNODE, self.gp.GRB.Callback.MIPNODE,
]: ]:
self.model.cbLazy(lhs, sense, rhs) self.model.cbLazy(cobj.lhs, cobj.sense, cobj.rhs)
else:
self.model.addConstr(lhs, sense, rhs, name)
else: else:
self.model.addConstr(cobj.lhs, cobj.sense, cobj.rhs, cobj.name)
elif isinstance(cobj, self.gp.TempConstr):
if self.cb_where in [ if self.cb_where in [
self.gp.GRB.Callback.MIPSOL, self.gp.GRB.Callback.MIPSOL,
self.gp.GRB.Callback.MIPNODE, self.gp.GRB.Callback.MIPNODE,
]: ]:
self.model.cbLazy(constraint) self.model.cbLazy(cobj)
else:
self.model.addConstr(cobj, name=name)
else: else:
self.model.addConstr(constraint, name=name) raise Exception(f"unknown constraint type: {cobj.__class__.__name__}")
@overrides @overrides
def add_cut(self, cobj: Any) -> None: def add_cut(self, cobj: Any) -> None:
@ -325,21 +335,27 @@ class GurobiSolver(InternalSolver):
var.ub = value var.ub = value
@overrides @overrides
def extract_constraint(self, cid: str) -> Any: def extract_constraint(self, cid: str) -> ExtractedGurobiConstraint:
self._raise_if_callback() self._raise_if_callback()
assert self.model is not None assert self.model is not None
constr = self.model.getConstrByName(cid) constr = self.model.getConstrByName(cid)
cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName) cobj = ExtractedGurobiConstraint(
lhs=self.model.getRow(constr),
sense=constr.sense,
rhs=constr.RHS,
name=constr.ConstrName,
)
self.model.remove(constr) self.model.remove(constr)
return cobj return cobj
@overrides @overrides
def is_constraint_satisfied( def is_constraint_satisfied(
self, self,
cobj: Any, cobj: ExtractedGurobiConstraint,
tol: float = 1e-6, tol: float = 1e-6,
) -> bool: ) -> bool:
lhs, sense, rhs, name = cobj assert isinstance(cobj, ExtractedGurobiConstraint)
lhs, sense, rhs, _ = cobj.lhs, cobj.sense, cobj.rhs, cobj.name
if self.cb_where is not None: if self.cb_where is not None:
lhs_value = lhs.getConstant() lhs_value = lhs.getConstant()
for i in range(lhs.size()): for i in range(lhs.size()):
@ -433,18 +449,22 @@ class GurobiSolver(InternalSolver):
self.model.update() self.model.update()
constraints: Dict[str, Constraint] = {} constraints: Dict[str, Constraint] = {}
for c in self.model.getConstrs(): for c in self.model.getConstrs():
constr = self._parse_gurobi_constraint(c)
assert c.constrName not in constraints
constraints[c.constrName] = constr
return constraints
def _parse_gurobi_constraint(self, c: Any) -> Constraint:
assert self.model is not None
expr = self.model.getRow(c) expr = self.model.getRow(c)
lhs: Dict[str, float] = {} lhs: Dict[str, float] = {}
for i in range(expr.size()): for i in range(expr.size()):
lhs[expr.getVar(i).varName] = expr.getCoeff(i) lhs[expr.getVar(i).varName] = expr.getCoeff(i)
assert c.constrName not in constraints return Constraint(rhs=c.rhs, lhs=lhs, sense=c.sense)
constraints[c.constrName] = Constraint(
rhs=c.rhs,
lhs=lhs,
sense=c.sense,
)
return constraints @overrides
def are_callbacks_supported(self) -> bool:
return True
class GurobiTestInstanceInfeasible(Instance): class GurobiTestInstanceInfeasible(Instance):
@ -506,5 +526,10 @@ class GurobiTestInstanceKnapsack(PyomoTestInstanceKnapsack):
@overrides @overrides
def build_lazy_constraint(self, model: Any, violation: Hashable) -> Any: def build_lazy_constraint(self, model: Any, violation: Hashable) -> Any:
x = model.getVarByName("x[0]") # TODO: Replace by plain constraint
return x <= 0.0 return ExtractedGurobiConstraint(
lhs=1.0 * model.getVarByName("x[0]"),
sense="<",
rhs=0.0,
name="cut",
)

@ -255,3 +255,10 @@ class InternalSolver(ABC):
@abstractmethod @abstractmethod
def build_test_instance_knapsack(self) -> Instance: def build_test_instance_knapsack(self) -> Instance:
pass pass
def are_callbacks_supported(self) -> bool:
"""
Returns True if this solver supports native callbacks, such as lazy constraints
callback or user cuts callback.
"""
return False

@ -98,10 +98,8 @@ class BasePyomoSolver(InternalSolver):
lazy_cb: Optional[LazyCallback] = None, lazy_cb: Optional[LazyCallback] = None,
user_cut_cb: Optional[UserCutCallback] = None, user_cut_cb: Optional[UserCutCallback] = None,
) -> MIPSolveStats: ) -> MIPSolveStats:
if lazy_cb is not None: assert lazy_cb is None, "callbacks are not currently supported"
raise Exception("lazy callback not currently supported") assert user_cut_cb is None, "callbacks are not currently supported"
if user_cut_cb is not None:
raise Exception("user cut callback not currently supported")
total_wallclock_time = 0 total_wallclock_time = 0
streams: List[Any] = [StringIO()] streams: List[Any] = [StringIO()]
if tee: if tee:
@ -413,6 +411,9 @@ class BasePyomoSolver(InternalSolver):
sense=sense, sense=sense,
) )
def are_callbacks_supported(self) -> bool:
return False
class PyomoTestInstanceInfeasible(Instance): class PyomoTestInstanceInfeasible(Instance):
@overrides @overrides

@ -16,6 +16,9 @@ def run_internal_solver_tests(solver: InternalSolver) -> None:
run_basic_usage_tests(solver.clone()) run_basic_usage_tests(solver.clone())
run_warm_start_tests(solver.clone()) run_warm_start_tests(solver.clone())
run_infeasibility_tests(solver.clone()) run_infeasibility_tests(solver.clone())
run_iteration_cb_tests(solver.clone())
if solver.are_callbacks_supported():
run_lazy_cb_tests(solver.clone())
def run_basic_usage_tests(solver: InternalSolver) -> None: def run_basic_usage_tests(solver: InternalSolver) -> None:
@ -193,5 +196,25 @@ def run_iteration_cb_tests(solver: InternalSolver) -> None:
assert_equals(count, 5) assert_equals(count, 5)
def run_lazy_cb_tests(solver: InternalSolver) -> None:
instance = solver.build_test_instance_knapsack()
model = instance.to_model()
lazy_cb_count = 0
def lazy_cb(cb_solver: InternalSolver, cb_model: Any) -> None:
nonlocal lazy_cb_count
lazy_cb_count += 1
cobj = instance.build_lazy_constraint(model, "cut")
if not cb_solver.is_constraint_satisfied(cobj):
cb_solver.add_constraint(cobj)
solver.set_instance(instance, model)
solver.solve(lazy_cb=lazy_cb)
assert lazy_cb_count > 0
solution = solver.get_solution()
assert solution is not None
assert_equals(solution["x[0]"], 0.0)
def assert_equals(left: Any, right: Any) -> None: def assert_equals(left: Any, right: Any) -> None:
assert left == right, f"{left} != {right}" assert left == right, f"{left} != {right}"

@ -1,27 +0,0 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Any
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.internal import InternalSolver
logger = logging.getLogger(__name__)
def test_lazy_cb() -> None:
solver = GurobiSolver()
instance = solver.build_test_instance_knapsack()
model = instance.to_model()
def lazy_cb(cb_solver: InternalSolver, cb_model: Any) -> None:
cobj = (cb_model.getVarByName("x[0]") * 1.0, "<", 0.0, "cut")
if not cb_solver.is_constraint_satisfied(cobj):
cb_solver.add_constraint(cobj)
solver.set_instance(instance, model)
solver.solve(lazy_cb=lazy_cb)
solution = solver.get_solution()
assert solution["x[0]"] == 0.0
Loading…
Cancel
Save