Make TravelingSalesmanGenerator return data class

This commit is contained in:
2022-02-22 09:23:55 -06:00
parent 03e5acb11a
commit 87bba1b38e
3 changed files with 16 additions and 32 deletions

View File

@@ -1,6 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from dataclasses import dataclass
from typing import List, Tuple, Any, Optional, Dict
import networkx as nx
@@ -17,28 +18,10 @@ from miplearn.solvers.pyomo.base import BasePyomoSolver
from miplearn.types import ConstraintName
class ChallengeA:
def __init__(
self,
seed: int = 42,
n_training_instances: int = 500,
n_test_instances: int = 50,
) -> None:
np.random.seed(seed)
self.generator = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0),
n=randint(low=350, high=351),
gamma=uniform(loc=0.95, scale=0.1),
fix_cities=True,
round=True,
)
np.random.seed(seed + 1)
self.training_instances = self.generator.generate(n_training_instances)
np.random.seed(seed + 2)
self.test_instances = self.generator.generate(n_test_instances)
@dataclass
class TravelingSalesmanData:
n_cities: int
distances: np.ndarray
class TravelingSalesmanInstance(Instance):
@@ -180,8 +163,8 @@ class TravelingSalesmanGenerator:
self.fixed_n = None
self.fixed_cities = None
def generate(self, n_samples: int) -> List[TravelingSalesmanInstance]:
def _sample() -> TravelingSalesmanInstance:
def generate(self, n_samples: int) -> List[TravelingSalesmanData]:
def _sample() -> TravelingSalesmanData:
if self.fixed_cities is not None:
assert self.fixed_n is not None
n, cities = self.fixed_n, self.fixed_cities
@@ -191,7 +174,7 @@ class TravelingSalesmanGenerator:
distances = np.tril(distances) + np.triu(distances.T, 1)
if self.round:
distances = distances.round()
return TravelingSalesmanInstance(n, distances)
return TravelingSalesmanData(n, distances)
return [_sample() for _ in range(n_samples)]