mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Correctly store features and training data for file-based instances
This commit is contained in:
@@ -128,6 +128,7 @@ class TravelingSalesmanInstance(Instance):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_cities, distances):
|
def __init__(self, n_cities, distances):
|
||||||
|
super().__init__()
|
||||||
assert isinstance(distances, np.ndarray)
|
assert isinstance(distances, np.ndarray)
|
||||||
assert distances.shape == (n_cities, n_cities)
|
assert distances.shape == (n_cities, n_cities)
|
||||||
self.n_cities = n_cities
|
self.n_cities = n_cities
|
||||||
|
|||||||
@@ -135,8 +135,6 @@ class LearningSolver:
|
|||||||
|
|
||||||
# Initialize training sample
|
# Initialize training sample
|
||||||
training_sample: TrainingSample = {}
|
training_sample: TrainingSample = {}
|
||||||
if not hasattr(instance, "training_data"):
|
|
||||||
instance.training_data = []
|
|
||||||
instance.training_data += [training_sample]
|
instance.training_data += [training_sample]
|
||||||
|
|
||||||
# Initialize stats
|
# Initialize stats
|
||||||
@@ -151,7 +149,8 @@ class LearningSolver:
|
|||||||
|
|
||||||
# Extract features
|
# Extract features
|
||||||
extractor = FeaturesExtractor(self.internal_solver)
|
extractor = FeaturesExtractor(self.internal_solver)
|
||||||
instance.features = extractor.extract(instance)
|
instance.features.clear() # type: ignore
|
||||||
|
instance.features.update(extractor.extract(instance))
|
||||||
|
|
||||||
callback_args = (
|
callback_args = (
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -91,12 +91,14 @@ def test_solve_fit_from_disk():
|
|||||||
solver.solve(instances[0])
|
solver.solve(instances[0])
|
||||||
instance_loaded = read_pickle_gz(instances[0].filename)
|
instance_loaded = read_pickle_gz(instances[0].filename)
|
||||||
assert len(instance_loaded.training_data) > 0
|
assert len(instance_loaded.training_data) > 0
|
||||||
|
assert len(instance_loaded.features) > 0
|
||||||
|
|
||||||
# Test: parallel_solve
|
# Test: parallel_solve
|
||||||
solver.parallel_solve(instances)
|
solver.parallel_solve(instances)
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
instance_loaded = read_pickle_gz(instance.filename)
|
instance_loaded = read_pickle_gz(instance.filename)
|
||||||
assert len(instance.training_data) > 0
|
assert len(instance_loaded.training_data) > 0
|
||||||
|
assert len(instance_loaded.features) > 0
|
||||||
|
|
||||||
# Delete temporary files
|
# Delete temporary files
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
|
|||||||
Reference in New Issue
Block a user