diff --git a/miplearn/problems/stab.py b/miplearn/problems/stab.py index a64fb3c..97e5559 100644 --- a/miplearn/problems/stab.py +++ b/miplearn/problems/stab.py @@ -1,7 +1,9 @@ # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. -from typing import List, Dict + +from dataclasses import dataclass +from typing import List import networkx as nx import numpy as np @@ -14,26 +16,10 @@ from scipy.stats.distributions import rv_frozen from miplearn.instance.base import Instance -class ChallengeA: - def __init__( - self, - seed: int = 42, - n_training_instances: int = 500, - n_test_instances: int = 50, - ) -> None: - np.random.seed(seed) - self.generator = MaxWeightStableSetGenerator( - w=uniform(loc=100.0, scale=50.0), - n=randint(low=200, high=201), - p=uniform(loc=0.05, scale=0.0), - fix_graph=True, - ) - - np.random.seed(seed + 1) - self.training_instances = self.generator.generate(n_training_instances) - - np.random.seed(seed + 2) - self.test_instances = self.generator.generate(n_test_instances) +@dataclass +class MaxWeightStableSetData: + graph: Graph + weights: np.ndarray class MaxWeightStableSetInstance(Instance): @@ -132,14 +118,14 @@ class MaxWeightStableSetGenerator: if fix_graph: self.graph = self._generate_graph() - def generate(self, n_samples: int) -> List[MaxWeightStableSetInstance]: - def _sample() -> MaxWeightStableSetInstance: + def generate(self, n_samples: int) -> List[MaxWeightStableSetData]: + def _sample() -> MaxWeightStableSetData: if self.graph is not None: graph = self.graph else: graph = self._generate_graph() weights = self.w.rvs(graph.number_of_nodes()) - return MaxWeightStableSetInstance(graph, weights) + return MaxWeightStableSetData(graph, weights) return [_sample() for _ in range(n_samples)] diff --git a/tests/problems/test_stab.py b/tests/problems/test_stab.py index df40d33..e04a5e0 100644 --- a/tests/problems/test_stab.py +++ b/tests/problems/test_stab.py @@ -29,8 +29,8 @@ def test_stab_generator_fixed_graph() -> None: p=uniform(loc=0.05, scale=0.0), fix_graph=True, ) - instances = gen.generate(1_000) - weights = np.array([instance.weights for instance in instances]) + data = gen.generate(1_000) + weights = np.array([d.weights for d in data]) weights_avg_actual = np.round(np.average(weights, axis=0)) weights_avg_expected = [55.0] * 10 assert list(weights_avg_actual) == weights_avg_expected @@ -46,8 +46,8 @@ def test_stab_generator_random_graph() -> None: p=uniform(loc=0.5, scale=0.0), fix_graph=False, ) - instances = gen.generate(1_000) - n_nodes = [instance.graph.number_of_nodes() for instance in instances] - n_edges = [instance.graph.number_of_edges() for instance in instances] + data = gen.generate(1_000) + n_nodes = [d.graph.number_of_nodes() for d in data] + n_edges = [d.graph.number_of_edges() for d in data] assert np.round(np.mean(n_nodes)) == 35.0 assert np.round(np.mean(n_edges), -1) == 300.0 diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index ad72bf4..da1c096 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -7,7 +7,10 @@ import os.path from scipy.stats import randint from miplearn.benchmark import BenchmarkRunner -from miplearn.problems.stab import MaxWeightStableSetGenerator +from miplearn.problems.stab import ( + MaxWeightStableSetInstance, + MaxWeightStableSetGenerator, +) from miplearn.solvers.learning import LearningSolver @@ -15,8 +18,14 @@ def test_benchmark() -> None: for n_jobs in [1, 4]: # Generate training and test instances generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26)) - train_instances = generator.generate(5) - test_instances = generator.generate(3) + train_instances = [ + MaxWeightStableSetInstance(data.graph, data.weights) + for data in generator.generate(5) + ] + test_instances = [ + MaxWeightStableSetInstance(data.graph, data.weights) + for data in generator.generate(3) + ] # Solve training instances training_solver = LearningSolver()