mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Reformat source code with Black; add pre-commit hooks and CI checks
This commit is contained in:
@@ -11,36 +11,42 @@ from scipy.stats import uniform, randint
|
||||
|
||||
def test_stab():
|
||||
graph = nx.cycle_graph(5)
|
||||
weights = [1., 1., 1., 1., 1.]
|
||||
weights = [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
instance = MaxWeightStableSetInstance(graph, weights)
|
||||
solver = LearningSolver()
|
||||
solver.solve(instance)
|
||||
assert instance.lower_bound == 2.
|
||||
|
||||
|
||||
assert instance.lower_bound == 2.0
|
||||
|
||||
|
||||
def test_stab_generator_fixed_graph():
|
||||
np.random.seed(42)
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
||||
gen = MaxWeightStableSetGenerator(w=uniform(loc=50., scale=10.),
|
||||
n=randint(low=10, high=11),
|
||||
p=uniform(loc=0.05, scale=0.),
|
||||
fix_graph=True)
|
||||
|
||||
gen = MaxWeightStableSetGenerator(
|
||||
w=uniform(loc=50.0, scale=10.0),
|
||||
n=randint(low=10, high=11),
|
||||
p=uniform(loc=0.05, scale=0.0),
|
||||
fix_graph=True,
|
||||
)
|
||||
instances = gen.generate(1_000)
|
||||
weights = np.array([instance.weights for instance in instances])
|
||||
weights_avg_actual = np.round(np.average(weights, axis=0))
|
||||
weights_avg_expected = [55.0] * 10
|
||||
assert list(weights_avg_actual) == weights_avg_expected
|
||||
|
||||
|
||||
|
||||
|
||||
def test_stab_generator_random_graph():
|
||||
np.random.seed(42)
|
||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
||||
gen = MaxWeightStableSetGenerator(w=uniform(loc=50., scale=10.),
|
||||
n=randint(low=30, high=41),
|
||||
p=uniform(loc=0.5, scale=0.),
|
||||
fix_graph=False)
|
||||
|
||||
gen = MaxWeightStableSetGenerator(
|
||||
w=uniform(loc=50.0, scale=10.0),
|
||||
n=randint(low=30, high=41),
|
||||
p=uniform(loc=0.5, scale=0.0),
|
||||
fix_graph=False,
|
||||
)
|
||||
instances = gen.generate(1_000)
|
||||
n_nodes = [instance.graph.number_of_nodes() for instance in instances]
|
||||
n_edges = [instance.graph.number_of_edges() for instance in instances]
|
||||
assert np.round(np.mean(n_nodes)) == 35.
|
||||
assert np.round(np.mean(n_edges), -1) == 300.
|
||||
assert np.round(np.mean(n_nodes)) == 35.0
|
||||
assert np.round(np.mean(n_edges), -1) == 300.0
|
||||
|
||||
Reference in New Issue
Block a user