parent
cfb17551f1
commit
9f2d7439dc
@ -0,0 +1,64 @@
|
|||||||
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
from typing import Any, TYPE_CHECKING, Hashable, Set
|
||||||
|
|
||||||
|
from miplearn import Component, Instance
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from miplearn.features import Features, TrainingSample
|
||||||
|
from miplearn.types import LearningSolveStats
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from miplearn.solvers.learning import LearningSolver
|
||||||
|
|
||||||
|
|
||||||
|
class UserCutsComponentNG(Component):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.enforced: Set[Hashable] = set()
|
||||||
|
|
||||||
|
def before_solve_mip(
|
||||||
|
self,
|
||||||
|
solver: "LearningSolver",
|
||||||
|
instance: Instance,
|
||||||
|
model: Any,
|
||||||
|
stats: LearningSolveStats,
|
||||||
|
features: Features,
|
||||||
|
training_data: TrainingSample,
|
||||||
|
) -> None:
|
||||||
|
self.enforced.clear()
|
||||||
|
|
||||||
|
def after_solve_mip(
|
||||||
|
self,
|
||||||
|
solver: "LearningSolver",
|
||||||
|
instance: Instance,
|
||||||
|
model: Any,
|
||||||
|
stats: LearningSolveStats,
|
||||||
|
features: Features,
|
||||||
|
training_data: TrainingSample,
|
||||||
|
) -> None:
|
||||||
|
training_data.user_cuts_enforced = set(self.enforced)
|
||||||
|
|
||||||
|
def user_cut_cb(
|
||||||
|
self,
|
||||||
|
solver: "LearningSolver",
|
||||||
|
instance: Instance,
|
||||||
|
model: Any,
|
||||||
|
) -> None:
|
||||||
|
assert solver.internal_solver is not None
|
||||||
|
logger.debug("Finding violated user cuts...")
|
||||||
|
cids = instance.find_violated_user_cuts(model)
|
||||||
|
logger.debug(f"Found {len(cids)} violated user cuts")
|
||||||
|
logger.debug("Building violated user cuts...")
|
||||||
|
for cid in cids:
|
||||||
|
assert isinstance(cid, Hashable)
|
||||||
|
cobj = instance.build_user_cut(model, cid)
|
||||||
|
assert cobj is not None
|
||||||
|
solver.internal_solver.add_cut(cobj)
|
||||||
|
self.enforced.add(cid)
|
||||||
|
if len(cids) > 0:
|
||||||
|
logger.info(f"Added {len(cids)} violated user cuts")
|
@ -1,31 +1,77 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
from typing import Any, List
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, FrozenSet, Hashable
|
||||||
|
|
||||||
|
import gurobipy as gp
|
||||||
|
import networkx as nx
|
||||||
import pytest
|
import pytest
|
||||||
|
from gurobipy import GRB
|
||||||
from networkx import Graph
|
from networkx import Graph
|
||||||
import networkx as nx
|
|
||||||
from scipy.stats import randint
|
|
||||||
|
|
||||||
from miplearn import Instance
|
from miplearn import Instance, LearningSolver, GurobiSolver
|
||||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
from miplearn.components.user_cuts import UserCutsComponentNG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GurobiStableSetProblem(Instance):
|
class GurobiStableSetProblem(Instance):
|
||||||
def __init__(self, graph: Graph) -> None:
|
def __init__(self, graph: Graph) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
|
self.nodes = list(self.graph.nodes)
|
||||||
|
|
||||||
def to_model(self) -> Any:
|
def to_model(self) -> Any:
|
||||||
pass
|
model = gp.Model()
|
||||||
|
x = [model.addVar(vtype=GRB.BINARY) for _ in range(len(self.nodes))]
|
||||||
|
model.setObjective(gp.quicksum(x), GRB.MAXIMIZE)
|
||||||
|
for e in list(self.graph.edges):
|
||||||
|
model.addConstr(x[e[0]] + x[e[1]] <= 1)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def has_user_cuts(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def find_violated_user_cuts(self, model):
|
||||||
|
assert isinstance(model, gp.Model)
|
||||||
|
vals = model.cbGetNodeRel(model.getVars())
|
||||||
|
violations = []
|
||||||
|
for clique in nx.find_cliques(self.graph):
|
||||||
|
lhs = sum(vals[i] for i in clique)
|
||||||
|
if lhs > 1:
|
||||||
|
violations += [frozenset(clique)]
|
||||||
|
return violations
|
||||||
|
|
||||||
|
def build_user_cut(self, model: Any, violation: Hashable) -> Any:
|
||||||
|
assert isinstance(violation, FrozenSet)
|
||||||
|
x = model.getVars()
|
||||||
|
cut = gp.quicksum([x[i] for i in violation]) <= 1
|
||||||
|
return cut
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def instance() -> Instance:
|
def stab_instance() -> Instance:
|
||||||
graph = nx.generators.random_graphs.binomial_graph(50, 0.5)
|
graph = nx.generators.random_graphs.binomial_graph(50, 0.50, seed=42)
|
||||||
return GurobiStableSetProblem(graph)
|
return GurobiStableSetProblem(graph)
|
||||||
|
|
||||||
|
|
||||||
def test_usage(instance: Instance) -> None:
|
@pytest.fixture
|
||||||
pass
|
def solver() -> LearningSolver:
|
||||||
|
return LearningSolver(
|
||||||
|
solver=lambda: GurobiSolver(),
|
||||||
|
components=[
|
||||||
|
UserCutsComponentNG(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage(
|
||||||
|
stab_instance: Instance,
|
||||||
|
solver: LearningSolver,
|
||||||
|
) -> None:
|
||||||
|
solver.solve(stab_instance)
|
||||||
|
sample = stab_instance.training_data[0]
|
||||||
|
assert sample.user_cuts_enforced is not None
|
||||||
|
assert len(sample.user_cuts_enforced) > 0
|
||||||
|
Loading…
Reference in new issue