mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Unify API for challenges
This commit is contained in:
@@ -11,18 +11,24 @@ from scipy.stats import uniform, randint, bernoulli
|
||||
from scipy.stats.distributions import rv_frozen
|
||||
|
||||
|
||||
class MaxWeightStableSetChallengeA:
|
||||
def __init__(self):
|
||||
class ChallengeA:
|
||||
def __init__(self,
|
||||
seed=42,
|
||||
n_training_instances=300,
|
||||
n_test_instances=50,
|
||||
):
|
||||
|
||||
np.random.seed(seed)
|
||||
self.generator = MaxWeightStableSetGenerator(w=uniform(loc=100., scale=50.),
|
||||
n=randint(low=200, high=201),
|
||||
p=uniform(loc=0.05, scale=0.0),
|
||||
fix_graph=True)
|
||||
|
||||
def get_training_instances(self):
|
||||
return self.generator.generate(300)
|
||||
np.random.seed(seed + 1)
|
||||
self.training_instances = self.generator.generate(n_training_instances)
|
||||
|
||||
def get_test_instances(self):
|
||||
return self.generator.generate(50)
|
||||
np.random.seed(seed + 2)
|
||||
self.test_instances = self.generator.generate(n_test_instances)
|
||||
|
||||
|
||||
class MaxWeightStableSetGenerator:
|
||||
|
||||
Reference in New Issue
Block a user