From f2520f33fb8928361b26512a1d250c0ce63ce102 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Sun, 4 Apr 2021 22:00:21 -0500 Subject: [PATCH] Correctly store features and training data for file-based instances --- miplearn/problems/tsp.py | 1 + miplearn/solvers/learning.py | 5 ++--- tests/solvers/test_learning_solver.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/miplearn/problems/tsp.py b/miplearn/problems/tsp.py index 439c489..e6d0cb8 100644 --- a/miplearn/problems/tsp.py +++ b/miplearn/problems/tsp.py @@ -128,6 +128,7 @@ class TravelingSalesmanInstance(Instance): """ def __init__(self, n_cities, distances): + super().__init__() assert isinstance(distances, np.ndarray) assert distances.shape == (n_cities, n_cities) self.n_cities = n_cities diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index eede7e6..f3e5aa8 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -135,8 +135,6 @@ class LearningSolver: # Initialize training sample training_sample: TrainingSample = {} - if not hasattr(instance, "training_data"): - instance.training_data = [] instance.training_data += [training_sample] # Initialize stats @@ -151,7 +149,8 @@ class LearningSolver: # Extract features extractor = FeaturesExtractor(self.internal_solver) - instance.features = extractor.extract(instance) + instance.features.clear() # type: ignore + instance.features.update(extractor.extract(instance)) callback_args = ( self, diff --git a/tests/solvers/test_learning_solver.py b/tests/solvers/test_learning_solver.py index 913eb6b..11502a2 100644 --- a/tests/solvers/test_learning_solver.py +++ b/tests/solvers/test_learning_solver.py @@ -91,12 +91,14 @@ def test_solve_fit_from_disk(): solver.solve(instances[0]) instance_loaded = read_pickle_gz(instances[0].filename) assert len(instance_loaded.training_data) > 0 + assert len(instance_loaded.features) > 0 # Test: parallel_solve solver.parallel_solve(instances) for instance in instances: instance_loaded = read_pickle_gz(instance.filename) - assert len(instance.training_data) > 0 + assert len(instance_loaded.training_data) > 0 + assert len(instance_loaded.features) > 0 # Delete temporary files for instance in instances: