parent
cfb17551f1
commit
9f2d7439dc
@ -0,0 +1,64 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from typing import Any, TYPE_CHECKING, Hashable, Set
|
||||
|
||||
from miplearn import Component, Instance
|
||||
|
||||
import logging
|
||||
|
||||
from miplearn.features import Features, TrainingSample
|
||||
from miplearn.types import LearningSolveStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
|
||||
|
||||
class UserCutsComponentNG(Component):
|
||||
def __init__(self) -> None:
|
||||
self.enforced: Set[Hashable] = set()
|
||||
|
||||
def before_solve_mip(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
stats: LearningSolveStats,
|
||||
features: Features,
|
||||
training_data: TrainingSample,
|
||||
) -> None:
|
||||
self.enforced.clear()
|
||||
|
||||
def after_solve_mip(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
stats: LearningSolveStats,
|
||||
features: Features,
|
||||
training_data: TrainingSample,
|
||||
) -> None:
|
||||
training_data.user_cuts_enforced = set(self.enforced)
|
||||
|
||||
def user_cut_cb(
|
||||
self,
|
||||
solver: "LearningSolver",
|
||||
instance: Instance,
|
||||
model: Any,
|
||||
) -> None:
|
||||
assert solver.internal_solver is not None
|
||||
logger.debug("Finding violated user cuts...")
|
||||
cids = instance.find_violated_user_cuts(model)
|
||||
logger.debug(f"Found {len(cids)} violated user cuts")
|
||||
logger.debug("Building violated user cuts...")
|
||||
for cid in cids:
|
||||
assert isinstance(cid, Hashable)
|
||||
cobj = instance.build_user_cut(model, cid)
|
||||
assert cobj is not None
|
||||
solver.internal_solver.add_cut(cobj)
|
||||
self.enforced.add(cid)
|
||||
if len(cids) > 0:
|
||||
logger.info(f"Added {len(cids)} violated user cuts")
|
@ -1,31 +1,77 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from typing import Any, List
|
||||
|
||||
import logging
|
||||
from typing import Any, FrozenSet, Hashable
|
||||
|
||||
import gurobipy as gp
|
||||
import networkx as nx
|
||||
import pytest
|
||||
from gurobipy import GRB
|
||||
from networkx import Graph
|
||||
import networkx as nx
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn import Instance
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
||||
from miplearn import Instance, LearningSolver, GurobiSolver
|
||||
from miplearn.components.user_cuts import UserCutsComponentNG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GurobiStableSetProblem(Instance):
|
||||
def __init__(self, graph: Graph) -> None:
|
||||
super().__init__()
|
||||
self.graph = graph
|
||||
self.nodes = list(self.graph.nodes)
|
||||
|
||||
def to_model(self) -> Any:
|
||||
pass
|
||||
model = gp.Model()
|
||||
x = [model.addVar(vtype=GRB.BINARY) for _ in range(len(self.nodes))]
|
||||
model.setObjective(gp.quicksum(x), GRB.MAXIMIZE)
|
||||
for e in list(self.graph.edges):
|
||||
model.addConstr(x[e[0]] + x[e[1]] <= 1)
|
||||
return model
|
||||
|
||||
def has_user_cuts(self) -> bool:
|
||||
return True
|
||||
|
||||
def find_violated_user_cuts(self, model):
|
||||
assert isinstance(model, gp.Model)
|
||||
vals = model.cbGetNodeRel(model.getVars())
|
||||
violations = []
|
||||
for clique in nx.find_cliques(self.graph):
|
||||
lhs = sum(vals[i] for i in clique)
|
||||
if lhs > 1:
|
||||
violations += [frozenset(clique)]
|
||||
return violations
|
||||
|
||||
def build_user_cut(self, model: Any, violation: Hashable) -> Any:
|
||||
assert isinstance(violation, FrozenSet)
|
||||
x = model.getVars()
|
||||
cut = gp.quicksum([x[i] for i in violation]) <= 1
|
||||
return cut
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def instance() -> Instance:
|
||||
graph = nx.generators.random_graphs.binomial_graph(50, 0.5)
|
||||
def stab_instance() -> Instance:
|
||||
graph = nx.generators.random_graphs.binomial_graph(50, 0.50, seed=42)
|
||||
return GurobiStableSetProblem(graph)
|
||||
|
||||
|
||||
def test_usage(instance: Instance) -> None:
|
||||
pass
|
||||
@pytest.fixture
|
||||
def solver() -> LearningSolver:
|
||||
return LearningSolver(
|
||||
solver=lambda: GurobiSolver(),
|
||||
components=[
|
||||
UserCutsComponentNG(),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_usage(
|
||||
stab_instance: Instance,
|
||||
solver: LearningSolver,
|
||||
) -> None:
|
||||
solver.solve(stab_instance)
|
||||
sample = stab_instance.training_data[0]
|
||||
assert sample.user_cuts_enforced is not None
|
||||
assert len(sample.user_cuts_enforced) > 0
|
||||
|
Loading…
Reference in new issue