Make TravelingSalesmanGenerator return data class

master
Alinson S. Xavier 4 years ago
parent 03e5acb11a
commit 87bba1b38e
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -1,6 +1,7 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from dataclasses import dataclass
from typing import List, Tuple, Any, Optional, Dict from typing import List, Tuple, Any, Optional, Dict
import networkx as nx import networkx as nx
@ -17,28 +18,10 @@ from miplearn.solvers.pyomo.base import BasePyomoSolver
from miplearn.types import ConstraintName from miplearn.types import ConstraintName
class ChallengeA: @dataclass
def __init__( class TravelingSalesmanData:
self, n_cities: int
seed: int = 42, distances: np.ndarray
n_training_instances: int = 500,
n_test_instances: int = 50,
) -> None:
np.random.seed(seed)
self.generator = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0),
n=randint(low=350, high=351),
gamma=uniform(loc=0.95, scale=0.1),
fix_cities=True,
round=True,
)
np.random.seed(seed + 1)
self.training_instances = self.generator.generate(n_training_instances)
np.random.seed(seed + 2)
self.test_instances = self.generator.generate(n_test_instances)
class TravelingSalesmanInstance(Instance): class TravelingSalesmanInstance(Instance):
@ -180,8 +163,8 @@ class TravelingSalesmanGenerator:
self.fixed_n = None self.fixed_n = None
self.fixed_cities = None self.fixed_cities = None
def generate(self, n_samples: int) -> List[TravelingSalesmanInstance]: def generate(self, n_samples: int) -> List[TravelingSalesmanData]:
def _sample() -> TravelingSalesmanInstance: def _sample() -> TravelingSalesmanData:
if self.fixed_cities is not None: if self.fixed_cities is not None:
assert self.fixed_n is not None assert self.fixed_n is not None
n, cities = self.fixed_n, self.fixed_cities n, cities = self.fixed_n, self.fixed_cities
@ -191,7 +174,7 @@ class TravelingSalesmanGenerator:
distances = np.tril(distances) + np.triu(distances.T, 1) distances = np.tril(distances) + np.triu(distances.T, 1)
if self.round: if self.round:
distances = distances.round() distances = distances.round()
return TravelingSalesmanInstance(n, distances) return TravelingSalesmanData(n, distances)
return [_sample() for _ in range(n_samples)] return [_sample() for _ in range(n_samples)]

@ -13,7 +13,7 @@ from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent from miplearn.components.primal import PrimalSolutionComponent
from miplearn.features.sample import Sample, MemorySample from miplearn.features.sample import Sample, MemorySample
from miplearn.problems.tsp import TravelingSalesmanGenerator from miplearn.problems.tsp import TravelingSalesmanGenerator, TravelingSalesmanInstance
from miplearn.solvers.learning import LearningSolver from miplearn.solvers.learning import LearningSolver
from miplearn.solvers.tests import assert_equals from miplearn.solvers.tests import assert_equals
@ -108,7 +108,8 @@ def test_usage() -> None:
] ]
) )
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6)) gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
instance = gen.generate(1)[0] data = gen.generate(1)
instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances)
solver.solve(instance) solver.solve(instance)
solver.fit([instance]) solver.fit([instance])
stats = solver.solve(instance) stats = solver.solve(instance)

@ -14,17 +14,17 @@ from miplearn.solvers.tests import assert_equals
def test_generator() -> None: def test_generator() -> None:
instances = TravelingSalesmanGenerator( data = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0), x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0), y=uniform(loc=0.0, scale=1000.0),
n=randint(low=100, high=101), n=randint(low=100, high=101),
gamma=uniform(loc=0.95, scale=0.1), gamma=uniform(loc=0.95, scale=0.1),
fix_cities=True, fix_cities=True,
).generate(100) ).generate(100)
assert len(instances) == 100 assert len(data) == 100
assert instances[0].n_cities == 100 assert data[0].n_cities == 100
assert norm(instances[0].distances - instances[0].distances.T) < 1e-6 assert norm(data[0].distances - data[0].distances.T) < 1e-6
d = [instance.distances[0, 1] for instance in instances] d = [d.distances[0, 1] for d in data]
assert np.std(d) > 0 assert np.std(d) > 0

Loading…
Cancel
Save