Make TravelingSalesmanGenerator return data class

This commit is contained in:
2022-02-22 09:23:55 -06:00
parent 03e5acb11a
commit 87bba1b38e
3 changed files with 16 additions and 32 deletions

View File

@@ -13,7 +13,7 @@ from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent
from miplearn.features.sample import Sample, MemorySample
from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.problems.tsp import TravelingSalesmanGenerator, TravelingSalesmanInstance
from miplearn.solvers.learning import LearningSolver
from miplearn.solvers.tests import assert_equals
@@ -108,7 +108,8 @@ def test_usage() -> None:
]
)
gen = TravelingSalesmanGenerator(n=randint(low=5, high=6))
instance = gen.generate(1)[0]
data = gen.generate(1)
instance = TravelingSalesmanInstance(data[0].n_cities, data[0].distances)
solver.solve(instance)
solver.fit([instance])
stats = solver.solve(instance)

View File

@@ -14,17 +14,17 @@ from miplearn.solvers.tests import assert_equals
def test_generator() -> None:
instances = TravelingSalesmanGenerator(
data = TravelingSalesmanGenerator(
x=uniform(loc=0.0, scale=1000.0),
y=uniform(loc=0.0, scale=1000.0),
n=randint(low=100, high=101),
gamma=uniform(loc=0.95, scale=0.1),
fix_cities=True,
).generate(100)
assert len(instances) == 100
assert instances[0].n_cities == 100
assert norm(instances[0].distances - instances[0].distances.T) < 1e-6
d = [instance.distances[0, 1] for instance in instances]
assert len(data) == 100
assert data[0].n_cities == 100
assert norm(data[0].distances - data[0].distances.T) < 1e-6
d = [d.distances[0, 1] for d in data]
assert np.std(d) > 0