Implement TravelingSalesmanPerturber

This commit is contained in:
2025-12-08 15:10:24 -06:00
parent 4137378bb8
commit 1d44980a7b
23 changed files with 128 additions and 90 deletions

View File

@@ -17,56 +17,30 @@ def test_tsp_generator() -> None:
gen = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0),
n=randint(low=3, high=4),
n=randint(low=5, high=6),
gamma=uniform(loc=1.0, scale=0.25),
fix_cities=True,
round=True,
)
data = gen.generate(2)
data = gen.generate(1)
assert data[0].distances.tolist() == [
[0.0, 591.0, 996.0],
[591.0, 0.0, 765.0],
[996.0, 765.0, 0.0],
]
assert data[1].distances.tolist() == [
[0.0, 556.0, 853.0],
[556.0, 0.0, 779.0],
[853.0, 779.0, 0.0],
[0.0, 525.0, 950.0, 392.0, 382.0],
[525.0, 0.0, 752.0, 761.0, 178.0],
[950.0, 752.0, 0.0, 809.0, 721.0],
[392.0, 761.0, 809.0, 0.0, 700.0],
[382.0, 178.0, 721.0, 700.0, 0.0],
]
def test_tsp() -> None:
data = TravelingSalesmanData(
n_cities=6,
distances=squareform(
pdist(
[
[0.0, 0.0],
[1.0, 0.0],
[2.0, 0.0],
[3.0, 0.0],
[0.0, 1.0],
[3.0, 1.0],
]
)
),
)
model = build_tsp_model_gurobipy(data)
model = build_tsp_model_gurobipy(data[0])
model.optimize()
assert model.inner.getAttr("x", model.inner.getVars()) == [
1.0,
0.0,
0.0,
1.0,
0.0,
1.0,
0.0,
0.0,
0.0,
1.0,
0.0,
0.0,
0.0,
1.0,
1.0,
0.0,
0.0,
]