diff --git a/miplearn/problems/stab.py b/miplearn/problems/stab.py index 6da07df..8bfd504 100644 --- a/miplearn/problems/stab.py +++ b/miplearn/problems/stab.py @@ -11,18 +11,24 @@ from scipy.stats import uniform, randint, bernoulli from scipy.stats.distributions import rv_frozen -class MaxWeightStableSetChallengeA: - def __init__(self): +class ChallengeA: + def __init__(self, + seed=42, + n_training_instances=300, + n_test_instances=50, + ): + + np.random.seed(seed) self.generator = MaxWeightStableSetGenerator(w=uniform(loc=100., scale=50.), n=randint(low=200, high=201), p=uniform(loc=0.05, scale=0.0), fix_graph=True) - - def get_training_instances(self): - return self.generator.generate(300) - - def get_test_instances(self): - return self.generator.generate(50) + + np.random.seed(seed + 1) + self.training_instances = self.generator.generate(n_training_instances) + + np.random.seed(seed + 2) + self.test_instances = self.generator.generate(n_test_instances) class MaxWeightStableSetGenerator: