You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
MIPLearn/miplearn/components/user_cuts.py

65 lines
1.9 KiB

# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import Any, TYPE_CHECKING, Hashable, Set
from miplearn import Component, Instance
import logging
from miplearn.features import Features, TrainingSample
from miplearn.types import LearningSolveStats
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
class UserCutsComponentNG(Component):
def __init__(self) -> None:
self.enforced: Set[Hashable] = set()
def before_solve_mip(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
stats: LearningSolveStats,
features: Features,
training_data: TrainingSample,
) -> None:
self.enforced.clear()
def after_solve_mip(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
stats: LearningSolveStats,
features: Features,
training_data: TrainingSample,
) -> None:
training_data.user_cuts_enforced = set(self.enforced)
def user_cut_cb(
self,
solver: "LearningSolver",
instance: Instance,
model: Any,
) -> None:
assert solver.internal_solver is not None
logger.debug("Finding violated user cuts...")
cids = instance.find_violated_user_cuts(model)
logger.debug(f"Found {len(cids)} violated user cuts")
logger.debug("Building violated user cuts...")
for cid in cids:
assert isinstance(cid, Hashable)
cobj = instance.build_user_cut(model, cid)
assert cobj is not None
solver.internal_solver.add_cut(cobj)
self.enforced.add(cid)
if len(cids) > 0:
logger.info(f"Added {len(cids)} violated user cuts")