Correctly store features and training data for file-based instances

master
Alinson S. Xavier 5 years ago
parent 025e08f85e
commit f2520f33fb
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -128,6 +128,7 @@ class TravelingSalesmanInstance(Instance):
""" """
def __init__(self, n_cities, distances): def __init__(self, n_cities, distances):
super().__init__()
assert isinstance(distances, np.ndarray) assert isinstance(distances, np.ndarray)
assert distances.shape == (n_cities, n_cities) assert distances.shape == (n_cities, n_cities)
self.n_cities = n_cities self.n_cities = n_cities

@ -135,8 +135,6 @@ class LearningSolver:
# Initialize training sample # Initialize training sample
training_sample: TrainingSample = {} training_sample: TrainingSample = {}
if not hasattr(instance, "training_data"):
instance.training_data = []
instance.training_data += [training_sample] instance.training_data += [training_sample]
# Initialize stats # Initialize stats
@ -151,7 +149,8 @@ class LearningSolver:
# Extract features # Extract features
extractor = FeaturesExtractor(self.internal_solver) extractor = FeaturesExtractor(self.internal_solver)
instance.features = extractor.extract(instance) instance.features.clear() # type: ignore
instance.features.update(extractor.extract(instance))
callback_args = ( callback_args = (
self, self,

@ -91,12 +91,14 @@ def test_solve_fit_from_disk():
solver.solve(instances[0]) solver.solve(instances[0])
instance_loaded = read_pickle_gz(instances[0].filename) instance_loaded = read_pickle_gz(instances[0].filename)
assert len(instance_loaded.training_data) > 0 assert len(instance_loaded.training_data) > 0
assert len(instance_loaded.features) > 0
# Test: parallel_solve # Test: parallel_solve
solver.parallel_solve(instances) solver.parallel_solve(instances)
for instance in instances: for instance in instances:
instance_loaded = read_pickle_gz(instance.filename) instance_loaded = read_pickle_gz(instance.filename)
assert len(instance.training_data) > 0 assert len(instance_loaded.training_data) > 0
assert len(instance_loaded.features) > 0
# Delete temporary files # Delete temporary files
for instance in instances: for instance in instances:

Loading…
Cancel
Save