diff --git a/miplearn/problems/tsp.py b/miplearn/problems/tsp.py index 4261fea..fc3ae75 100644 --- a/miplearn/problems/tsp.py +++ b/miplearn/problems/tsp.py @@ -1,6 +1,7 @@ # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. +from dataclasses import dataclass from typing import List, Tuple, Any, Optional, Dict import networkx as nx @@ -17,28 +18,10 @@ from miplearn.solvers.pyomo.base import BasePyomoSolver from miplearn.types import ConstraintName -class ChallengeA: - def __init__( - self, - seed: int = 42, - n_training_instances: int = 500, - n_test_instances: int = 50, - ) -> None: - np.random.seed(seed) - self.generator = TravelingSalesmanGenerator( - x=uniform(loc=0.0, scale=1000.0), - y=uniform(loc=0.0, scale=1000.0), - n=randint(low=350, high=351), - gamma=uniform(loc=0.95, scale=0.1), - fix_cities=True, - round=True, - ) - - np.random.seed(seed + 1) - self.training_instances = self.generator.generate(n_training_instances) - - np.random.seed(seed + 2) - self.test_instances = self.generator.generate(n_test_instances) +@dataclass +class TravelingSalesmanData: + n_cities: int + distances: np.ndarray class TravelingSalesmanInstance(Instance): @@ -180,8 +163,8 @@ class TravelingSalesmanGenerator: self.fixed_n = None self.fixed_cities = None - def generate(self, n_samples: int) -> List[TravelingSalesmanInstance]: - def _sample() -> TravelingSalesmanInstance: + def generate(self, n_samples: int) -> List[TravelingSalesmanData]: + def _sample() -> TravelingSalesmanData: if self.fixed_cities is not None: assert self.fixed_n is not None n, cities = self.fixed_n, self.fixed_cities @@ -191,7 +174,7 @@ class TravelingSalesmanGenerator: distances = np.tril(distances) + np.triu(distances.T, 1) if self.round: distances = distances.round() - return TravelingSalesmanInstance(n, distances) + return TravelingSalesmanData(n, distances) return [_sample() for _ in range(n_samples)] diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index 6acebee..83b1096 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -13,7 +13,7 @@ from miplearn.classifiers.threshold import Threshold from miplearn.components import classifier_evaluation_dict from miplearn.components.primal import PrimalSolutionComponent from miplearn.features.sample import Sample, MemorySample -from miplearn.problems.tsp import TravelingSalesmanGenerator +from miplearn.problems.tsp import TravelingSalesmanGenerator, TravelingSalesmanInstance from miplearn.solvers.learning import LearningSolver from miplearn.solvers.tests import assert_equals @@ -108,7 +108,8 @@ def test_usage() -> None: ] ) gen = TravelingSalesmanGenerator(n=randint(low=5, high=6)) - instance = gen.generate(1)[0] + data = gen.generate(1) + instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances) solver.solve(instance) solver.fit([instance]) stats = solver.solve(instance) diff --git a/tests/problems/test_tsp.py b/tests/problems/test_tsp.py index f0216ee..5c6fbc8 100644 --- a/tests/problems/test_tsp.py +++ b/tests/problems/test_tsp.py @@ -14,17 +14,17 @@ from miplearn.solvers.tests import assert_equals def test_generator() -> None: - instances = TravelingSalesmanGenerator( + data = TravelingSalesmanGenerator( x=uniform(loc=0.0, scale=1000.0), y=uniform(loc=0.0, scale=1000.0), n=randint(low=100, high=101), gamma=uniform(loc=0.95, scale=0.1), fix_cities=True, ).generate(100) - assert len(instances) == 100 - assert instances[0].n_cities == 100 - assert norm(instances[0].distances - instances[0].distances.T) < 1e-6 - d = [instance.distances[0, 1] for instance in instances] + assert len(data) == 100 + assert data[0].n_cities == 100 + assert norm(data[0].distances - data[0].distances.T) < 1e-6 + d = [d.distances[0, 1] for d in data] assert np.std(d) > 0