Make TravelingSalesmanGenerator return data class

master
Alinson S. Xavier 4 years ago
parent 03e5acb11a
commit 87bba1b38e
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -1,6 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from dataclasses import dataclass
from typing import List, Tuple, Any, Optional, Dict
import networkx as nx
@ -17,28 +18,10 @@ from miplearn.solvers.pyomo.base import BasePyomoSolver
from miplearn.types import ConstraintName
class ChallengeA:
def __init__(
self,
seed: int = 42,
n_training_instances: int = 500,
n_test_instances: int = 50,
) -> None:
np.random.seed(seed)
self.generator = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0),
n=randint(low=350, high=351),
gamma=uniform(loc=0.95, scale=0.1),
fix_cities=True,
round=True,
)
np.random.seed(seed + 1)
self.training_instances = self.generator.generate(n_training_instances)
np.random.seed(seed + 2)
self.test_instances = self.generator.generate(n_test_instances)
@dataclass
class TravelingSalesmanData:
n_cities: int
distances: np.ndarray
class TravelingSalesmanInstance(Instance):
@ -180,8 +163,8 @@ class TravelingSalesmanGenerator:
self.fixed_n = None
self.fixed_cities = None
def generate(self, n_samples: int) -> List[TravelingSalesmanInstance]:
def _sample() -> TravelingSalesmanInstance:
def generate(self, n_samples: int) -> List[TravelingSalesmanData]:
def _sample() -> TravelingSalesmanData:
if self.fixed_cities is not None:
assert self.fixed_n is not None
n, cities = self.fixed_n, self.fixed_cities
@ -191,7 +174,7 @@ class TravelingSalesmanGenerator:
distances = np.tril(distances) + np.triu(distances.T, 1)
if self.round:
distances = distances.round()
return TravelingSalesmanInstance(n, distances)
return TravelingSalesmanData(n, distances)
return [_sample() for _ in range(n_samples)]

@ -13,7 +13,7 @@ from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent
from miplearn.features.sample import Sample, MemorySample
from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.problems.tsp import TravelingSalesmanGenerator, TravelingSalesmanInstance
from miplearn.solvers.learning import LearningSolver
from miplearn.solvers.tests import assert_equals
@ -108,7 +108,8 @@ def test_usage() -> None:
]
)
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
instance = gen.generate(1)[0]
data = gen.generate(1)
instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances)
solver.solve(instance)
solver.fit([instance])
stats = solver.solve(instance)

@ -14,17 +14,17 @@ from miplearn.solvers.tests import assert_equals
def test_generator() -> None:
instances = TravelingSalesmanGenerator(
data = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0),
n=randint(low=100, high=101),
gamma=uniform(loc=0.95, scale=0.1),
fix_cities=True,
).generate(100)
assert len(instances) == 100
assert instances[0].n_cities == 100
assert norm(instances[0].distances - instances[0].distances.T) < 1e-6
d = [instance.distances[0, 1] for instance in instances]
assert len(data) == 100
assert data[0].n_cities == 100
assert norm(data[0].distances - data[0].distances.T) < 1e-6
d = [d.distances[0, 1] for d in data]
assert np.std(d) > 0

Loading…
Cancel
Save