Remove load_state and save_state

pull/1/head
Alinson S. Xavier 6 years ago
parent 68e972c635
commit 959cb54d27
No known key found for this signature in database
GPG Key ID: A796166E4E218E02

@ -118,3 +118,29 @@ class BranchPriorityComponent(Component):
instance_features = instance.get_instance_features() instance_features = instance.get_instance_features()
var_features = instance.get_variable_features(var, index) var_features = instance.get_variable_features(var, index)
return np.hstack([instance_features, var_features]) return np.hstack([instance_features, var_features])
def merge(self, other_components):
keys = set(self.x_train.keys())
for comp in other_components:
self.pending_instances += comp.pending_instances
keys = keys.union(set(comp.x_train.keys()))
# Merge x_train and y_train
for key in keys:
x_train_submatrices = [comp.x_train[key]
for comp in other_components
if key in comp.x_train.keys()]
y_train_submatrices = [comp.y_train[key]
for comp in other_components
if key in comp.y_train.keys()]
if key in self.x_train.keys():
x_train_submatrices += [self.x_train[key]]
y_train_submatrices += [self.y_train[key]]
self.x_train[key] = np.vstack(x_train_submatrices)
self.y_train[key] = np.vstack(y_train_submatrices)
# Merge trained ML predictors
for comp in other_components:
for key in comp.predictors.keys():
if key not in self.predictors.keys():
self.predictors[key] = comp.predictors[key]

@ -327,21 +327,3 @@ class LearningSolver:
return return
for component in self.components.values(): for component in self.components.values():
component.fit(training_instances) component.fit(training_instances)
def save_state(self, filename):
with open(filename, "wb") as file:
pickle.dump({
"version": 2,
"components": self.components,
}, file)
def load_state(self, filename):
with open(filename, "rb") as file:
data = pickle.load(file)
assert data["version"] == 2
for (component_name, component) in data["components"].items():
if component_name not in self.components.keys():
continue
else:
self.components[component_name].merge([component])

@ -18,8 +18,6 @@ def test_benchmark():
# Training phase... # Training phase...
training_solver = LearningSolver() training_solver = LearningSolver()
training_solver.parallel_solve(train_instances, n_jobs=10) training_solver.parallel_solve(train_instances, n_jobs=10)
training_solver.fit()
training_solver.save_state("data.bin")
# Test phase... # Test phase...
test_solvers = { test_solvers = {
@ -27,7 +25,7 @@ def test_benchmark():
"Strategy B": LearningSolver(), "Strategy B": LearningSolver(),
} }
benchmark = BenchmarkRunner(test_solvers) benchmark = BenchmarkRunner(test_solvers)
benchmark.load_state("data.bin") benchmark.fit(train_instances)
benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2) benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2)
assert benchmark.raw_results().values.shape == (12,13) assert benchmark.raw_results().values.shape == (12,13)

@ -41,29 +41,6 @@ def test_solver():
solver.fit() solver.fit()
solver.solve(instance) solver.solve(instance)
# def test_solve_save_load_state():
# instance = _get_instance()
# components_before = {
# "warm-start": WarmStartComponent(),
# }
# solver = LearningSolver(components=components_before)
# solver.solve(instance)
# solver.fit()
# solver.save_state("/tmp/knapsack_train.bin")
# prev_x_train_len = len(solver.components["warm-start"].x_train)
# prev_y_train_len = len(solver.components["warm-start"].y_train)
# components_after = {
# "warm-start": WarmStartComponent(),
# }
# solver = LearningSolver(components=components_after)
# solver.load_state("/tmp/knapsack_train.bin")
# assert len(solver.components.keys()) == 1
# assert len(solver.components["warm-start"].x_train) == prev_x_train_len
# assert len(solver.components["warm-start"].y_train) == prev_y_train_len
def test_parallel_solve(): def test_parallel_solve():
instances = [_get_instance() for _ in range(10)] instances = [_get_instance() for _ in range(10)]
solver = LearningSolver() solver = LearningSolver()

Loading…
Cancel
Save