mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 10:28:52 -06:00
Merge branch 'feature/new-py-api' into feature/docs
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -15,6 +16,12 @@ from scipy.stats.distributions import rv_frozen
|
||||
from miplearn.instance.base import Instance
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaxWeightStableSetData:
|
||||
graph: Graph
|
||||
weights: np.ndarray
|
||||
|
||||
|
||||
class MaxWeightStableSetInstance(Instance):
|
||||
"""An instance of the Maximum-Weight Stable Set Problem.
|
||||
|
||||
@@ -87,16 +94,30 @@ class MaxWeightStableSetGenerator:
|
||||
if fix_graph:
|
||||
self.graph = self._generate_graph()
|
||||
|
||||
def generate(self, n_samples: int) -> List[MaxWeightStableSetInstance]:
|
||||
def _sample() -> MaxWeightStableSetInstance:
|
||||
def generate(self, n_samples: int) -> List[MaxWeightStableSetData]:
|
||||
def _sample() -> MaxWeightStableSetData:
|
||||
if self.graph is not None:
|
||||
graph = self.graph
|
||||
else:
|
||||
graph = self._generate_graph()
|
||||
weights = self.w.rvs(graph.number_of_nodes())
|
||||
return MaxWeightStableSetInstance(graph, weights)
|
||||
return MaxWeightStableSetData(graph, weights)
|
||||
|
||||
return [_sample() for _ in range(n_samples)]
|
||||
|
||||
def _generate_graph(self) -> Graph:
|
||||
return nx.generators.random_graphs.binomial_graph(self.n.rvs(), self.p.rvs())
|
||||
|
||||
|
||||
def build_stab_model(data: MaxWeightStableSetData) -> pe.ConcreteModel:
|
||||
model = pe.ConcreteModel()
|
||||
nodes = list(data.graph.nodes)
|
||||
model.x = pe.Var(nodes, domain=pe.Binary)
|
||||
model.OBJ = pe.Objective(
|
||||
expr=sum(model.x[v] * data.weights[v] for v in nodes),
|
||||
sense=pe.maximize,
|
||||
)
|
||||
model.clique_eqs = pe.ConstraintList()
|
||||
for clique in nx.find_cliques(data.graph):
|
||||
model.clique_eqs.add(sum(model.x[v] for v in clique) <= 1)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user