mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make MaxWeightStableSetGenerator return data class
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from typing import List, Dict
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -14,26 +16,10 @@ from scipy.stats.distributions import rv_frozen
|
||||
from miplearn.instance.base import Instance
|
||||
|
||||
|
||||
class ChallengeA:
|
||||
def __init__(
|
||||
self,
|
||||
seed: int = 42,
|
||||
n_training_instances: int = 500,
|
||||
n_test_instances: int = 50,
|
||||
) -> None:
|
||||
np.random.seed(seed)
|
||||
self.generator = MaxWeightStableSetGenerator(
|
||||
w=uniform(loc=100.0, scale=50.0),
|
||||
n=randint(low=200, high=201),
|
||||
p=uniform(loc=0.05, scale=0.0),
|
||||
fix_graph=True,
|
||||
)
|
||||
|
||||
np.random.seed(seed + 1)
|
||||
self.training_instances = self.generator.generate(n_training_instances)
|
||||
|
||||
np.random.seed(seed + 2)
|
||||
self.test_instances = self.generator.generate(n_test_instances)
|
||||
@dataclass
|
||||
class MaxWeightStableSetData:
|
||||
graph: Graph
|
||||
weights: np.ndarray
|
||||
|
||||
|
||||
class MaxWeightStableSetInstance(Instance):
|
||||
@@ -132,14 +118,14 @@ class MaxWeightStableSetGenerator:
|
||||
if fix_graph:
|
||||
self.graph = self._generate_graph()
|
||||
|
||||
def generate(self, n_samples: int) -> List[MaxWeightStableSetInstance]:
|
||||
def _sample() -> MaxWeightStableSetInstance:
|
||||
def generate(self, n_samples: int) -> List[MaxWeightStableSetData]:
|
||||
def _sample() -> MaxWeightStableSetData:
|
||||
if self.graph is not None:
|
||||
graph = self.graph
|
||||
else:
|
||||
graph = self._generate_graph()
|
||||
weights = self.w.rvs(graph.number_of_nodes())
|
||||
return MaxWeightStableSetInstance(graph, weights)
|
||||
return MaxWeightStableSetData(graph, weights)
|
||||
|
||||
return [_sample() for _ in range(n_samples)]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user