Add user cut callbacks; begin rewrite of UserCutsComponent

This commit is contained in:
2021-04-06 12:46:37 -05:00
parent cfb17551f1
commit 9f2d7439dc
11 changed files with 213 additions and 43 deletions

View File

@@ -198,6 +198,14 @@ class Component:
) -> None:
return
def user_cut_cb(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
) -> None:
return
def evaluate(self, instances: List[Instance]) -> List:
ev = []
for instance in instances:

View File

@@ -0,0 +1,64 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Any, TYPE_CHECKING, Hashable, Set
from miplearn import Component, Instance
import logging
from miplearn.features import Features, TrainingSample
from miplearn.types import LearningSolveStats
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
class UserCutsComponentNG(Component):
def __init__(self) -> None:
self.enforced: Set[Hashable] = set()
def before_solve_mip(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
stats: LearningSolveStats,
features: Features,
training_data: TrainingSample,
) -> None:
self.enforced.clear()
def after_solve_mip(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
stats: LearningSolveStats,
features: Features,
training_data: TrainingSample,
) -> None:
training_data.user_cuts_enforced = set(self.enforced)
def user_cut_cb(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
) -> None:
assert solver.internal_solver is not None
logger.debug("Finding violated user cuts...")
cids = instance.find_violated_user_cuts(model)
logger.debug(f"Found {len(cids)} violated user cuts")
logger.debug("Building violated user cuts...")
for cid in cids:
assert isinstance(cid, Hashable)
cobj = instance.build_user_cut(model, cid)
assert cobj is not None
solver.internal_solver.add_cut(cobj)
self.enforced.add(cid)
if len(cids) > 0:
logger.info(f"Added {len(cids)} violated user cuts")