Make MaxWeightStableSetGenerator return data class

This commit is contained in:
2022-02-22 09:16:37 -06:00
parent 1811492557
commit b0d63a0a2d
3 changed files with 27 additions and 32 deletions

View File

@@ -29,8 +29,8 @@ def test_stab_generator_fixed_graph() -> None:
p=uniform(loc=0.05, scale=0.0),
fix_graph=True,
)
instances = gen.generate(1_000)
weights = np.array([instance.weights for instance in instances])
data = gen.generate(1_000)
weights = np.array([d.weights for d in data])
weights_avg_actual = np.round(np.average(weights, axis=0))
weights_avg_expected = [55.0] * 10
assert list(weights_avg_actual) == weights_avg_expected
@@ -46,8 +46,8 @@ def test_stab_generator_random_graph() -> None:
p=uniform(loc=0.5, scale=0.0),
fix_graph=False,
)
instances = gen.generate(1_000)
n_nodes = [instance.graph.number_of_nodes() for instance in instances]
n_edges = [instance.graph.number_of_edges() for instance in instances]
data = gen.generate(1_000)
n_nodes = [d.graph.number_of_nodes() for d in data]
n_edges = [d.graph.number_of_edges() for d in data]
assert np.round(np.mean(n_nodes)) == 35.0
assert np.round(np.mean(n_edges), -1) == 300.0

View File

@@ -7,7 +7,10 @@ import os.path
from scipy.stats import randint
from miplearn.benchmark import BenchmarkRunner
from miplearn.problems.stab import MaxWeightStableSetGenerator
from miplearn.problems.stab import (
MaxWeightStableSetInstance,
MaxWeightStableSetGenerator,
)
from miplearn.solvers.learning import LearningSolver
@@ -15,8 +18,14 @@ def test_benchmark() -> None:
for n_jobs in [1, 4]:
# Generate training and test instances
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
train_instances = generator.generate(5)
test_instances = generator.generate(3)
train_instances = [
MaxWeightStableSetInstance(data.graph, data.weights)
for data in generator.generate(5)
]
test_instances = [
MaxWeightStableSetInstance(data.graph, data.weights)
for data in generator.generate(3)
]
# Solve training instances
training_solver = LearningSolver()