mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make MaxWeightStableSetGenerator return data class
This commit is contained in:
@@ -29,8 +29,8 @@ def test_stab_generator_fixed_graph() -> None:
|
||||
p=uniform(loc=0.05, scale=0.0),
|
||||
fix_graph=True,
|
||||
)
|
||||
instances = gen.generate(1_000)
|
||||
weights = np.array([instance.weights for instance in instances])
|
||||
data = gen.generate(1_000)
|
||||
weights = np.array([d.weights for d in data])
|
||||
weights_avg_actual = np.round(np.average(weights, axis=0))
|
||||
weights_avg_expected = [55.0] * 10
|
||||
assert list(weights_avg_actual) == weights_avg_expected
|
||||
@@ -46,8 +46,8 @@ def test_stab_generator_random_graph() -> None:
|
||||
p=uniform(loc=0.5, scale=0.0),
|
||||
fix_graph=False,
|
||||
)
|
||||
instances = gen.generate(1_000)
|
||||
n_nodes = [instance.graph.number_of_nodes() for instance in instances]
|
||||
n_edges = [instance.graph.number_of_edges() for instance in instances]
|
||||
data = gen.generate(1_000)
|
||||
n_nodes = [d.graph.number_of_nodes() for d in data]
|
||||
n_edges = [d.graph.number_of_edges() for d in data]
|
||||
assert np.round(np.mean(n_nodes)) == 35.0
|
||||
assert np.round(np.mean(n_edges), -1) == 300.0
|
||||
|
||||
@@ -7,7 +7,10 @@ import os.path
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.benchmark import BenchmarkRunner
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
||||
from miplearn.problems.stab import (
|
||||
MaxWeightStableSetInstance,
|
||||
MaxWeightStableSetGenerator,
|
||||
)
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
|
||||
|
||||
@@ -15,8 +18,14 @@ def test_benchmark() -> None:
|
||||
for n_jobs in [1, 4]:
|
||||
# Generate training and test instances
|
||||
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
|
||||
train_instances = generator.generate(5)
|
||||
test_instances = generator.generate(3)
|
||||
train_instances = [
|
||||
MaxWeightStableSetInstance(data.graph, data.weights)
|
||||
for data in generator.generate(5)
|
||||
]
|
||||
test_instances = [
|
||||
MaxWeightStableSetInstance(data.graph, data.weights)
|
||||
for data in generator.generate(3)
|
||||
]
|
||||
|
||||
# Solve training instances
|
||||
training_solver = LearningSolver()
|
||||
|
||||
Reference in New Issue
Block a user