mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make TravelingSalesmanGenerator return data class
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Any, Optional, Dict
|
||||
|
||||
import networkx as nx
|
||||
@@ -17,28 +18,10 @@ from miplearn.solvers.pyomo.base import BasePyomoSolver
|
||||
from miplearn.types import ConstraintName
|
||||
|
||||
|
||||
class ChallengeA:
|
||||
def __init__(
|
||||
self,
|
||||
seed: int = 42,
|
||||
n_training_instances: int = 500,
|
||||
n_test_instances: int = 50,
|
||||
) -> None:
|
||||
np.random.seed(seed)
|
||||
self.generator = TravelingSalesmanGenerator(
|
||||
x=uniform(loc=0.0, scale=1000.0),
|
||||
y=uniform(loc=0.0, scale=1000.0),
|
||||
n=randint(low=350, high=351),
|
||||
gamma=uniform(loc=0.95, scale=0.1),
|
||||
fix_cities=True,
|
||||
round=True,
|
||||
)
|
||||
|
||||
np.random.seed(seed + 1)
|
||||
self.training_instances = self.generator.generate(n_training_instances)
|
||||
|
||||
np.random.seed(seed + 2)
|
||||
self.test_instances = self.generator.generate(n_test_instances)
|
||||
@dataclass
|
||||
class TravelingSalesmanData:
|
||||
n_cities: int
|
||||
distances: np.ndarray
|
||||
|
||||
|
||||
class TravelingSalesmanInstance(Instance):
|
||||
@@ -180,8 +163,8 @@ class TravelingSalesmanGenerator:
|
||||
self.fixed_n = None
|
||||
self.fixed_cities = None
|
||||
|
||||
def generate(self, n_samples: int) -> List[TravelingSalesmanInstance]:
|
||||
def _sample() -> TravelingSalesmanInstance:
|
||||
def generate(self, n_samples: int) -> List[TravelingSalesmanData]:
|
||||
def _sample() -> TravelingSalesmanData:
|
||||
if self.fixed_cities is not None:
|
||||
assert self.fixed_n is not None
|
||||
n, cities = self.fixed_n, self.fixed_cities
|
||||
@@ -191,7 +174,7 @@ class TravelingSalesmanGenerator:
|
||||
distances = np.tril(distances) + np.triu(distances.T, 1)
|
||||
if self.round:
|
||||
distances = distances.round()
|
||||
return TravelingSalesmanInstance(n, distances)
|
||||
return TravelingSalesmanData(n, distances)
|
||||
|
||||
return [_sample() for _ in range(n_samples)]
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from miplearn.classifiers.threshold import Threshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.primal import PrimalSolutionComponent
|
||||
from miplearn.features.sample import Sample, MemorySample
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator, TravelingSalesmanInstance
|
||||
from miplearn.solvers.learning import LearningSolver
|
||||
from miplearn.solvers.tests import assert_equals
|
||||
|
||||
@@ -108,7 +108,8 @@ def test_usage() -> None:
|
||||
]
|
||||
)
|
||||
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
||||
instance = gen.generate(1)[0]
|
||||
data = gen.generate(1)
|
||||
instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances)
|
||||
solver.solve(instance)
|
||||
solver.fit([instance])
|
||||
stats = solver.solve(instance)
|
||||
|
||||
@@ -14,17 +14,17 @@ from miplearn.solvers.tests import assert_equals
|
||||
|
||||
|
||||
def test_generator() -> None:
|
||||
instances = TravelingSalesmanGenerator(
|
||||
data = TravelingSalesmanGenerator(
|
||||
x=uniform(loc=0.0, scale=1000.0),
|
||||
y=uniform(loc=0.0, scale=1000.0),
|
||||
n=randint(low=100, high=101),
|
||||
gamma=uniform(loc=0.95, scale=0.1),
|
||||
fix_cities=True,
|
||||
).generate(100)
|
||||
assert len(instances) == 100
|
||||
assert instances[0].n_cities == 100
|
||||
assert norm(instances[0].distances - instances[0].distances.T) < 1e-6
|
||||
d = [instance.distances[0, 1] for instance in instances]
|
||||
assert len(data) == 100
|
||||
assert data[0].n_cities == 100
|
||||
assert norm(data[0].distances - data[0].distances.T) < 1e-6
|
||||
d = [d.distances[0, 1] for d in data]
|
||||
assert np.std(d) > 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user