mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make TravelingSalesmanGenerator return data class
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import List, Tuple, Any, Optional, Dict
|
from typing import List, Tuple, Any, Optional, Dict
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
@@ -17,28 +18,10 @@ from miplearn.solvers.pyomo.base import BasePyomoSolver
|
|||||||
from miplearn.types import ConstraintName
|
from miplearn.types import ConstraintName
|
||||||
|
|
||||||
|
|
||||||
class ChallengeA:
|
@dataclass
|
||||||
def __init__(
|
class TravelingSalesmanData:
|
||||||
self,
|
n_cities: int
|
||||||
seed: int = 42,
|
distances: np.ndarray
|
||||||
n_training_instances: int = 500,
|
|
||||||
n_test_instances: int = 50,
|
|
||||||
) -> None:
|
|
||||||
np.random.seed(seed)
|
|
||||||
self.generator = TravelingSalesmanGenerator(
|
|
||||||
x=uniform(loc=0.0, scale=1000.0),
|
|
||||||
y=uniform(loc=0.0, scale=1000.0),
|
|
||||||
n=randint(low=350, high=351),
|
|
||||||
gamma=uniform(loc=0.95, scale=0.1),
|
|
||||||
fix_cities=True,
|
|
||||||
round=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
np.random.seed(seed + 1)
|
|
||||||
self.training_instances = self.generator.generate(n_training_instances)
|
|
||||||
|
|
||||||
np.random.seed(seed + 2)
|
|
||||||
self.test_instances = self.generator.generate(n_test_instances)
|
|
||||||
|
|
||||||
|
|
||||||
class TravelingSalesmanInstance(Instance):
|
class TravelingSalesmanInstance(Instance):
|
||||||
@@ -180,8 +163,8 @@ class TravelingSalesmanGenerator:
|
|||||||
self.fixed_n = None
|
self.fixed_n = None
|
||||||
self.fixed_cities = None
|
self.fixed_cities = None
|
||||||
|
|
||||||
def generate(self, n_samples: int) -> List[TravelingSalesmanInstance]:
|
def generate(self, n_samples: int) -> List[TravelingSalesmanData]:
|
||||||
def _sample() -> TravelingSalesmanInstance:
|
def _sample() -> TravelingSalesmanData:
|
||||||
if self.fixed_cities is not None:
|
if self.fixed_cities is not None:
|
||||||
assert self.fixed_n is not None
|
assert self.fixed_n is not None
|
||||||
n, cities = self.fixed_n, self.fixed_cities
|
n, cities = self.fixed_n, self.fixed_cities
|
||||||
@@ -191,7 +174,7 @@ class TravelingSalesmanGenerator:
|
|||||||
distances = np.tril(distances) + np.triu(distances.T, 1)
|
distances = np.tril(distances) + np.triu(distances.T, 1)
|
||||||
if self.round:
|
if self.round:
|
||||||
distances = distances.round()
|
distances = distances.round()
|
||||||
return TravelingSalesmanInstance(n, distances)
|
return TravelingSalesmanData(n, distances)
|
||||||
|
|
||||||
return [_sample() for _ in range(n_samples)]
|
return [_sample() for _ in range(n_samples)]
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from miplearn.classifiers.threshold import Threshold
|
|||||||
from miplearn.components import classifier_evaluation_dict
|
from miplearn.components import classifier_evaluation_dict
|
||||||
from miplearn.components.primal import PrimalSolutionComponent
|
from miplearn.components.primal import PrimalSolutionComponent
|
||||||
from miplearn.features.sample import Sample, MemorySample
|
from miplearn.features.sample import Sample, MemorySample
|
||||||
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
from miplearn.problems.tsp import TravelingSalesmanGenerator, TravelingSalesmanInstance
|
||||||
from miplearn.solvers.learning import LearningSolver
|
from miplearn.solvers.learning import LearningSolver
|
||||||
from miplearn.solvers.tests import assert_equals
|
from miplearn.solvers.tests import assert_equals
|
||||||
|
|
||||||
@@ -108,7 +108,8 @@ def test_usage() -> None:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
|
||||||
instance = gen.generate(1)[0]
|
data = gen.generate(1)
|
||||||
|
instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances)
|
||||||
solver.solve(instance)
|
solver.solve(instance)
|
||||||
solver.fit([instance])
|
solver.fit([instance])
|
||||||
stats = solver.solve(instance)
|
stats = solver.solve(instance)
|
||||||
|
|||||||
@@ -14,17 +14,17 @@ from miplearn.solvers.tests import assert_equals
|
|||||||
|
|
||||||
|
|
||||||
def test_generator() -> None:
|
def test_generator() -> None:
|
||||||
instances = TravelingSalesmanGenerator(
|
data = TravelingSalesmanGenerator(
|
||||||
x=uniform(loc=0.0, scale=1000.0),
|
x=uniform(loc=0.0, scale=1000.0),
|
||||||
y=uniform(loc=0.0, scale=1000.0),
|
y=uniform(loc=0.0, scale=1000.0),
|
||||||
n=randint(low=100, high=101),
|
n=randint(low=100, high=101),
|
||||||
gamma=uniform(loc=0.95, scale=0.1),
|
gamma=uniform(loc=0.95, scale=0.1),
|
||||||
fix_cities=True,
|
fix_cities=True,
|
||||||
).generate(100)
|
).generate(100)
|
||||||
assert len(instances) == 100
|
assert len(data) == 100
|
||||||
assert instances[0].n_cities == 100
|
assert data[0].n_cities == 100
|
||||||
assert norm(instances[0].distances - instances[0].distances.T) < 1e-6
|
assert norm(data[0].distances - data[0].distances.T) < 1e-6
|
||||||
d = [instance.distances[0, 1] for instance in instances]
|
d = [d.distances[0, 1] for d in data]
|
||||||
assert np.std(d) > 0
|
assert np.std(d) > 0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user