Make MaxWeightStableSetGenerator return data class

master
Alinson S. Xavier 4 years ago
parent 1811492557
commit b0d63a0a2d
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -1,7 +1,9 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from typing import List, Dict
from dataclasses import dataclass
from typing import List
import networkx as nx import networkx as nx
import numpy as np import numpy as np
@ -14,26 +16,10 @@ from scipy.stats.distributions import rv_frozen
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
class ChallengeA: @dataclass
def __init__( class MaxWeightStableSetData:
self, graph: Graph
seed: int = 42, weights: np.ndarray
n_training_instances: int = 500,
n_test_instances: int = 50,
) -> None:
np.random.seed(seed)
self.generator = MaxWeightStableSetGenerator(
w=uniform(loc=100.0, scale=50.0),
n=randint(low=200, high=201),
p=uniform(loc=0.05, scale=0.0),
fix_graph=True,
)
np.random.seed(seed + 1)
self.training_instances = self.generator.generate(n_training_instances)
np.random.seed(seed + 2)
self.test_instances = self.generator.generate(n_test_instances)
class MaxWeightStableSetInstance(Instance): class MaxWeightStableSetInstance(Instance):
@ -132,14 +118,14 @@ class MaxWeightStableSetGenerator:
if fix_graph: if fix_graph:
self.graph = self._generate_graph() self.graph = self._generate_graph()
def generate(self, n_samples: int) -> List[MaxWeightStableSetInstance]: def generate(self, n_samples: int) -> List[MaxWeightStableSetData]:
def _sample() -> MaxWeightStableSetInstance: def _sample() -> MaxWeightStableSetData:
if self.graph is not None: if self.graph is not None:
graph = self.graph graph = self.graph
else: else:
graph = self._generate_graph() graph = self._generate_graph()
weights = self.w.rvs(graph.number_of_nodes()) weights = self.w.rvs(graph.number_of_nodes())
return MaxWeightStableSetInstance(graph, weights) return MaxWeightStableSetData(graph, weights)
return [_sample() for _ in range(n_samples)] return [_sample() for _ in range(n_samples)]

@ -29,8 +29,8 @@ def test_stab_generator_fixed_graph() -> None:
p=uniform(loc=0.05, scale=0.0), p=uniform(loc=0.05, scale=0.0),
fix_graph=True, fix_graph=True,
) )
instances = gen.generate(1_000) data = gen.generate(1_000)
weights = np.array([instance.weights for instance in instances]) weights = np.array([d.weights for d in data])
weights_avg_actual = np.round(np.average(weights, axis=0)) weights_avg_actual = np.round(np.average(weights, axis=0))
weights_avg_expected = [55.0] * 10 weights_avg_expected = [55.0] * 10
assert list(weights_avg_actual) == weights_avg_expected assert list(weights_avg_actual) == weights_avg_expected
@ -46,8 +46,8 @@ def test_stab_generator_random_graph() -> None:
p=uniform(loc=0.5, scale=0.0), p=uniform(loc=0.5, scale=0.0),
fix_graph=False, fix_graph=False,
) )
instances = gen.generate(1_000) data = gen.generate(1_000)
n_nodes = [instance.graph.number_of_nodes() for instance in instances] n_nodes = [d.graph.number_of_nodes() for d in data]
n_edges = [instance.graph.number_of_edges() for instance in instances] n_edges = [d.graph.number_of_edges() for d in data]
assert np.round(np.mean(n_nodes)) == 35.0 assert np.round(np.mean(n_nodes)) == 35.0
assert np.round(np.mean(n_edges), -1) == 300.0 assert np.round(np.mean(n_edges), -1) == 300.0

@ -7,7 +7,10 @@ import os.path
from scipy.stats import randint from scipy.stats import randint
from miplearn.benchmark import BenchmarkRunner from miplearn.benchmark import BenchmarkRunner
from miplearn.problems.stab import MaxWeightStableSetGenerator from miplearn.problems.stab import (
MaxWeightStableSetInstance,
MaxWeightStableSetGenerator,
)
from miplearn.solvers.learning import LearningSolver from miplearn.solvers.learning import LearningSolver
@ -15,8 +18,14 @@ def test_benchmark() -> None:
for n_jobs in [1, 4]: for n_jobs in [1, 4]:
# Generate training and test instances # Generate training and test instances
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26)) generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
train_instances = generator.generate(5) train_instances = [
test_instances = generator.generate(3) MaxWeightStableSetInstance(data.graph, data.weights)
for data in generator.generate(5)
]
test_instances = [
MaxWeightStableSetInstance(data.graph, data.weights)
for data in generator.generate(3)
]
# Solve training instances # Solve training instances
training_solver = LearningSolver() training_solver = LearningSolver()

Loading…
Cancel
Save