Add user cut callbacks; begin rewrite of UserCutsComponent

This commit is contained in:
2021-04-06 12:46:37 -05:00
parent cfb17551f1
commit 9f2d7439dc
11 changed files with 213 additions and 43 deletions

View File

@@ -17,7 +17,7 @@ from miplearn.solvers.internal import (
LazyCallback,
MIPSolveStats,
)
from miplearn.types import VarIndex, SolverParams, Solution
from miplearn.types import VarIndex, SolverParams, Solution, UserCutCallback
logger = logging.getLogger(__name__)
@@ -153,41 +153,49 @@ class GurobiSolver(InternalSolver):
tee: bool = False,
iteration_cb: IterationCallback = None,
lazy_cb: LazyCallback = None,
user_cut_cb: UserCutCallback = None,
) -> MIPSolveStats:
self._raise_if_callback()
assert self.model is not None
if iteration_cb is None:
iteration_cb = lambda: False
# Create callback wrapper
def cb_wrapper(cb_model, cb_where):
try:
self.cb_where = cb_where
if cb_where in self.lazy_cb_where:
if lazy_cb is not None and cb_where in self.lazy_cb_where:
lazy_cb(self, self.model)
if user_cut_cb is not None and cb_where == self.gp.GRB.Callback.MIPNODE:
user_cut_cb(self, self.model)
except:
logger.exception("callback error")
finally:
self.cb_where = None
if lazy_cb:
# Configure Gurobi
if lazy_cb is not None:
self.params["LazyConstraints"] = 1
if user_cut_cb is not None:
self.params["PreCrush"] = 1
# Solve problem
total_wallclock_time = 0
total_nodes = 0
streams: List[Any] = [StringIO()]
if tee:
streams += [sys.stdout]
self._apply_params(streams)
if iteration_cb is None:
iteration_cb = lambda: False
while True:
with _RedirectOutput(streams):
if lazy_cb is None:
self.model.optimize()
else:
self.model.optimize(cb_wrapper)
self.model.optimize(cb_wrapper)
total_wallclock_time += self.model.runtime
total_nodes += int(self.model.nodeCount)
should_repeat = iteration_cb()
if not should_repeat:
break
# Fetch results and stats
log = streams[0].getvalue()
ub, lb = None, None
sense = "min" if self.model.modelSense == 1 else "max"
@@ -313,6 +321,11 @@ class GurobiSolver(InternalSolver):
else:
self.model.addConstr(constraint, name=name)
def add_cut(self, cobj: Any) -> None:
assert self.model is not None
assert self.cb_where == self.gp.GRB.Callback.MIPNODE
self.model.cbCut(cobj)
def _clear_warm_start(self) -> None:
for (varname, vardict) in self._all_vars.items():
for (idx, var) in vardict.items():
@@ -421,7 +434,6 @@ class GurobiSolver(InternalSolver):
}
def __setstate__(self, state):
self.params = state["params"]
self.lazy_cb_where = state["lazy_cb_where"]
self.instance = None

View File

@@ -16,6 +16,7 @@ from miplearn.types import (
Solution,
BranchPriorities,
Constraint,
UserCutCallback,
)
logger = logging.getLogger(__name__)
@@ -51,6 +52,7 @@ class InternalSolver(ABC):
tee: bool = False,
iteration_cb: IterationCallback = None,
lazy_cb: LazyCallback = None,
user_cut_cb: UserCutCallback = None,
) -> MIPSolveStats:
"""
Solves the currently loaded instance. After this method finishes,
@@ -72,6 +74,9 @@ class InternalSolver(ABC):
- Querying if a constraint is satisfied
- Adding a new constraint to the problem
Additional operations may be allowed by specific subclasses.
user_cut_cb: UserCutCallback
This function is called whenever the solver found a new integer-infeasible
solution and needs to generate cutting planes to cut it off.
tee: bool
If true, prints the solver log to the screen.
"""
@@ -146,7 +151,7 @@ class InternalSolver(ABC):
`get_solution`. Missing values indicate variables whose priorities
should not be modified.
"""
raise Exception("Not implemented")
raise NotImplementedError()
@abstractmethod
def get_constraint_ids(self) -> List[str]:
@@ -180,6 +185,13 @@ class InternalSolver(ABC):
"""
pass
def add_cut(self, cobj: Any) -> None:
"""
Adds a cutting plane to the model. This function can only be called from a user
cut callback.
"""
raise NotImplementedError()
@abstractmethod
def extract_constraint(self, cid: str) -> Constraint:
"""

View File

@@ -125,18 +125,22 @@ class LearningSolver:
) -> LearningSolveStats:
# Generate model
# -------------------------------------------------------
if model is None:
with _RedirectOutput([]):
model = instance.to_model()
# Initialize training sample
# -------------------------------------------------------
training_sample = TrainingSample()
instance.training_data += [training_sample]
# Initialize stats
# -------------------------------------------------------
stats: LearningSolveStats = {}
# Initialize internal solver
# -------------------------------------------------------
self.tee = tee
self.internal_solver = self.solver_factory()
assert self.internal_solver is not None
@@ -144,6 +148,7 @@ class LearningSolver:
self.internal_solver.set_instance(instance, model)
# Extract features
# -------------------------------------------------------
FeaturesExtractor(self.internal_solver).extract(instance)
callback_args = (
@@ -156,6 +161,7 @@ class LearningSolver:
)
# Solve root LP relaxation
# -------------------------------------------------------
if self.solve_lp:
logger.debug("Running before_solve_lp callbacks...")
for component in self.components.values():
@@ -172,37 +178,50 @@ class LearningSolver:
for component in self.components.values():
component.after_solve_lp(*callback_args)
# Define wrappers
# Callback wrappers
# -------------------------------------------------------
def iteration_cb_wrapper() -> bool:
should_repeat = False
assert isinstance(instance, Instance)
for comp in self.components.values():
if comp.iteration_cb(self, instance, model):
should_repeat = True
return should_repeat
def lazy_cb_wrapper(
cb_solver: LearningSolver,
cb_solver: InternalSolver,
cb_model: Any,
) -> None:
assert isinstance(instance, Instance)
for comp in self.components.values():
comp.lazy_cb(self, instance, model)
def user_cut_cb_wrapper(
cb_solver: InternalSolver,
cb_model: Any,
) -> None:
for comp in self.components.values():
comp.user_cut_cb(self, instance, model)
lazy_cb = None
if self.use_lazy_cb:
lazy_cb = lazy_cb_wrapper
user_cut_cb = None
if instance.has_user_cuts():
user_cut_cb = user_cut_cb_wrapper
# Before-solve callbacks
# -------------------------------------------------------
logger.debug("Running before_solve_mip callbacks...")
for component in self.components.values():
component.before_solve_mip(*callback_args)
# Solve MIP
# -------------------------------------------------------
logger.info("Solving MIP...")
mip_stats = self.internal_solver.solve(
tee=tee,
iteration_cb=iteration_cb_wrapper,
user_cut_cb=user_cut_cb,
lazy_cb=lazy_cb,
)
stats.update(cast(LearningSolveStats, mip_stats))
@@ -216,17 +235,20 @@ class LearningSolver:
stats["Mode"] = self.mode
# Add some information to training_sample
# -------------------------------------------------------
training_sample.lower_bound = stats["Lower bound"]
training_sample.upper_bound = stats["Upper bound"]
training_sample.mip_log = stats["MIP log"]
training_sample.solution = self.internal_solver.get_solution()
# After-solve callbacks
# -------------------------------------------------------
logger.debug("Calling after_solve_mip callbacks...")
for component in self.components.values():
component.after_solve_mip(*callback_args)
# Write to file, if necessary
# Flush
# -------------------------------------------------------
if not discard_output:
instance.flush()

View File

@@ -23,7 +23,7 @@ from miplearn.solvers.internal import (
LazyCallback,
MIPSolveStats,
)
from miplearn.types import VarIndex, SolverParams, Solution
from miplearn.types import VarIndex, SolverParams, Solution, UserCutCallback
logger = logging.getLogger(__name__)
@@ -81,9 +81,12 @@ class BasePyomoSolver(InternalSolver):
tee: bool = False,
iteration_cb: IterationCallback = None,
lazy_cb: LazyCallback = None,
user_cut_cb: UserCutCallback = None,
) -> MIPSolveStats:
if lazy_cb is not None:
raise Exception("lazy callback not supported")
raise Exception("lazy callback not currently supported")
if user_cut_cb is not None:
raise Exception("user cut callback not currently supported")
total_wallclock_time = 0
streams: List[Any] = [StringIO()]
if tee:
@@ -318,19 +321,19 @@ class BasePyomoSolver(InternalSolver):
return {}
def set_constraint_sense(self, cid: str, sense: str) -> None:
raise Exception("Not implemented")
raise NotImplementedError()
def extract_constraint(self, cid: str) -> Constraint:
raise Exception("Not implemented")
raise NotImplementedError()
def is_constraint_satisfied(self, cobj: Constraint, tol: float = 1e-6) -> bool:
raise Exception("Not implemented")
raise NotImplementedError()
def is_infeasible(self) -> bool:
return self._termination_condition == TerminationCondition.infeasible
def get_dual(self, cid):
raise Exception("Not implemented")
raise NotImplementedError()
def get_sense(self) -> str:
return self._obj_sense