mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Unify API for challenges
This commit is contained in:
@@ -11,18 +11,24 @@ from scipy.stats import uniform, randint, bernoulli
|
|||||||
from scipy.stats.distributions import rv_frozen
|
from scipy.stats.distributions import rv_frozen
|
||||||
|
|
||||||
|
|
||||||
class MaxWeightStableSetChallengeA:
|
class ChallengeA:
|
||||||
def __init__(self):
|
def __init__(self,
|
||||||
|
seed=42,
|
||||||
|
n_training_instances=300,
|
||||||
|
n_test_instances=50,
|
||||||
|
):
|
||||||
|
|
||||||
|
np.random.seed(seed)
|
||||||
self.generator = MaxWeightStableSetGenerator(w=uniform(loc=100., scale=50.),
|
self.generator = MaxWeightStableSetGenerator(w=uniform(loc=100., scale=50.),
|
||||||
n=randint(low=200, high=201),
|
n=randint(low=200, high=201),
|
||||||
p=uniform(loc=0.05, scale=0.0),
|
p=uniform(loc=0.05, scale=0.0),
|
||||||
fix_graph=True)
|
fix_graph=True)
|
||||||
|
|
||||||
def get_training_instances(self):
|
np.random.seed(seed + 1)
|
||||||
return self.generator.generate(300)
|
self.training_instances = self.generator.generate(n_training_instances)
|
||||||
|
|
||||||
def get_test_instances(self):
|
np.random.seed(seed + 2)
|
||||||
return self.generator.generate(50)
|
self.test_instances = self.generator.generate(n_test_instances)
|
||||||
|
|
||||||
|
|
||||||
class MaxWeightStableSetGenerator:
|
class MaxWeightStableSetGenerator:
|
||||||
|
|||||||
Reference in New Issue
Block a user