mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make MaxWeightStableSetGenerator return data class
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
from typing import List, Dict
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -14,26 +16,10 @@ from scipy.stats.distributions import rv_frozen
|
|||||||
from miplearn.instance.base import Instance
|
from miplearn.instance.base import Instance
|
||||||
|
|
||||||
|
|
||||||
class ChallengeA:
|
@dataclass
|
||||||
def __init__(
|
class MaxWeightStableSetData:
|
||||||
self,
|
graph: Graph
|
||||||
seed: int = 42,
|
weights: np.ndarray
|
||||||
n_training_instances: int = 500,
|
|
||||||
n_test_instances: int = 50,
|
|
||||||
) -> None:
|
|
||||||
np.random.seed(seed)
|
|
||||||
self.generator = MaxWeightStableSetGenerator(
|
|
||||||
w=uniform(loc=100.0, scale=50.0),
|
|
||||||
n=randint(low=200, high=201),
|
|
||||||
p=uniform(loc=0.05, scale=0.0),
|
|
||||||
fix_graph=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
np.random.seed(seed + 1)
|
|
||||||
self.training_instances = self.generator.generate(n_training_instances)
|
|
||||||
|
|
||||||
np.random.seed(seed + 2)
|
|
||||||
self.test_instances = self.generator.generate(n_test_instances)
|
|
||||||
|
|
||||||
|
|
||||||
class MaxWeightStableSetInstance(Instance):
|
class MaxWeightStableSetInstance(Instance):
|
||||||
@@ -132,14 +118,14 @@ class MaxWeightStableSetGenerator:
|
|||||||
if fix_graph:
|
if fix_graph:
|
||||||
self.graph = self._generate_graph()
|
self.graph = self._generate_graph()
|
||||||
|
|
||||||
def generate(self, n_samples: int) -> List[MaxWeightStableSetInstance]:
|
def generate(self, n_samples: int) -> List[MaxWeightStableSetData]:
|
||||||
def _sample() -> MaxWeightStableSetInstance:
|
def _sample() -> MaxWeightStableSetData:
|
||||||
if self.graph is not None:
|
if self.graph is not None:
|
||||||
graph = self.graph
|
graph = self.graph
|
||||||
else:
|
else:
|
||||||
graph = self._generate_graph()
|
graph = self._generate_graph()
|
||||||
weights = self.w.rvs(graph.number_of_nodes())
|
weights = self.w.rvs(graph.number_of_nodes())
|
||||||
return MaxWeightStableSetInstance(graph, weights)
|
return MaxWeightStableSetData(graph, weights)
|
||||||
|
|
||||||
return [_sample() for _ in range(n_samples)]
|
return [_sample() for _ in range(n_samples)]
|
||||||
|
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ def test_stab_generator_fixed_graph() -> None:
|
|||||||
p=uniform(loc=0.05, scale=0.0),
|
p=uniform(loc=0.05, scale=0.0),
|
||||||
fix_graph=True,
|
fix_graph=True,
|
||||||
)
|
)
|
||||||
instances = gen.generate(1_000)
|
data = gen.generate(1_000)
|
||||||
weights = np.array([instance.weights for instance in instances])
|
weights = np.array([d.weights for d in data])
|
||||||
weights_avg_actual = np.round(np.average(weights, axis=0))
|
weights_avg_actual = np.round(np.average(weights, axis=0))
|
||||||
weights_avg_expected = [55.0] * 10
|
weights_avg_expected = [55.0] * 10
|
||||||
assert list(weights_avg_actual) == weights_avg_expected
|
assert list(weights_avg_actual) == weights_avg_expected
|
||||||
@@ -46,8 +46,8 @@ def test_stab_generator_random_graph() -> None:
|
|||||||
p=uniform(loc=0.5, scale=0.0),
|
p=uniform(loc=0.5, scale=0.0),
|
||||||
fix_graph=False,
|
fix_graph=False,
|
||||||
)
|
)
|
||||||
instances = gen.generate(1_000)
|
data = gen.generate(1_000)
|
||||||
n_nodes = [instance.graph.number_of_nodes() for instance in instances]
|
n_nodes = [d.graph.number_of_nodes() for d in data]
|
||||||
n_edges = [instance.graph.number_of_edges() for instance in instances]
|
n_edges = [d.graph.number_of_edges() for d in data]
|
||||||
assert np.round(np.mean(n_nodes)) == 35.0
|
assert np.round(np.mean(n_nodes)) == 35.0
|
||||||
assert np.round(np.mean(n_edges), -1) == 300.0
|
assert np.round(np.mean(n_edges), -1) == 300.0
|
||||||
|
|||||||
@@ -7,7 +7,10 @@ import os.path
|
|||||||
from scipy.stats import randint
|
from scipy.stats import randint
|
||||||
|
|
||||||
from miplearn.benchmark import BenchmarkRunner
|
from miplearn.benchmark import BenchmarkRunner
|
||||||
from miplearn.problems.stab import MaxWeightStableSetGenerator
|
from miplearn.problems.stab import (
|
||||||
|
MaxWeightStableSetInstance,
|
||||||
|
MaxWeightStableSetGenerator,
|
||||||
|
)
|
||||||
from miplearn.solvers.learning import LearningSolver
|
from miplearn.solvers.learning import LearningSolver
|
||||||
|
|
||||||
|
|
||||||
@@ -15,8 +18,14 @@ def test_benchmark() -> None:
|
|||||||
for n_jobs in [1, 4]:
|
for n_jobs in [1, 4]:
|
||||||
# Generate training and test instances
|
# Generate training and test instances
|
||||||
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
|
generator = MaxWeightStableSetGenerator(n=randint(low=25, high=26))
|
||||||
train_instances = generator.generate(5)
|
train_instances = [
|
||||||
test_instances = generator.generate(3)
|
MaxWeightStableSetInstance(data.graph, data.weights)
|
||||||
|
for data in generator.generate(5)
|
||||||
|
]
|
||||||
|
test_instances = [
|
||||||
|
MaxWeightStableSetInstance(data.graph, data.weights)
|
||||||
|
for data in generator.generate(3)
|
||||||
|
]
|
||||||
|
|
||||||
# Solve training instances
|
# Solve training instances
|
||||||
training_solver = LearningSolver()
|
training_solver = LearningSolver()
|
||||||
|
|||||||
Reference in New Issue
Block a user