mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 09:58:51 -06:00
Improve stable set generator
This commit is contained in:
@@ -4,27 +4,17 @@
|
||||
|
||||
from miplearn import LearningSolver, BenchmarkRunner
|
||||
from miplearn.warmstart import KnnWarmStartPredictor
|
||||
from miplearn.problems.stab import MaxStableSetInstance, MaxStableSetGenerator
|
||||
import networkx as nx
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
||||
from scipy.stats import randint
|
||||
import numpy as np
|
||||
import pyomo.environ as pe
|
||||
import os.path
|
||||
|
||||
|
||||
def test_benchmark():
|
||||
graph = nx.cycle_graph(10)
|
||||
base_weights = np.random.rand(10)
|
||||
|
||||
# Generate training and test instances
|
||||
train_instances = MaxStableSetGenerator(graph=graph,
|
||||
base_weights=base_weights,
|
||||
perturbation_scale=1.0,
|
||||
).generate(5)
|
||||
|
||||
test_instances = MaxStableSetGenerator(graph=graph,
|
||||
base_weights=base_weights,
|
||||
perturbation_scale=1.0,
|
||||
).generate(3)
|
||||
train_instances = MaxWeightStableSetGenerator(n=randint(low=25, high=26)).generate(5)
|
||||
test_instances = MaxWeightStableSetGenerator(n=randint(low=25, high=26)).generate(3)
|
||||
|
||||
# Training phase...
|
||||
training_solver = LearningSolver()
|
||||
|
||||
Reference in New Issue
Block a user