From ccfcbe4e64009fd99fd8eadfd6d0338fc2fa9342 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Fri, 31 Jan 2020 20:48:05 -0600 Subject: [PATCH] Unify API for challenges --- miplearn/problems/stab.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/miplearn/problems/stab.py b/miplearn/problems/stab.py index 6da07df..8bfd504 100644 --- a/miplearn/problems/stab.py +++ b/miplearn/problems/stab.py @@ -11,18 +11,24 @@ from scipy.stats import uniform, randint, bernoulli from scipy.stats.distributions import rv_frozen -class MaxWeightStableSetChallengeA: - def __init__(self): +class ChallengeA: + def __init__(self, + seed=42, + n_training_instances=300, + n_test_instances=50, + ): + + np.random.seed(seed) self.generator = MaxWeightStableSetGenerator(w=uniform(loc=100., scale=50.), n=randint(low=200, high=201), p=uniform(loc=0.05, scale=0.0), fix_graph=True) - - def get_training_instances(self): - return self.generator.generate(300) - - def get_test_instances(self): - return self.generator.generate(50) + + np.random.seed(seed + 1) + self.training_instances = self.generator.generate(n_training_instances) + + np.random.seed(seed + 2) + self.test_instances = self.generator.generate(n_test_instances) class MaxWeightStableSetGenerator: