From 959cb54d2792a4e7c16b14fa4f6fa311b5343753 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Mon, 24 Feb 2020 22:16:51 -0600 Subject: [PATCH] Remove load_state and save_state --- miplearn/components/branching.py | 26 ++++++++++++++++++++++++++ miplearn/solvers.py | 18 ------------------ miplearn/tests/test_benchmark.py | 4 +--- miplearn/tests/test_solver.py | 23 ----------------------- 4 files changed, 27 insertions(+), 44 deletions(-) diff --git a/miplearn/components/branching.py b/miplearn/components/branching.py index f35886e..8d29396 100644 --- a/miplearn/components/branching.py +++ b/miplearn/components/branching.py @@ -118,3 +118,29 @@ class BranchPriorityComponent(Component): instance_features = instance.get_instance_features() var_features = instance.get_variable_features(var, index) return np.hstack([instance_features, var_features]) + + def merge(self, other_components): + keys = set(self.x_train.keys()) + for comp in other_components: + self.pending_instances += comp.pending_instances + keys = keys.union(set(comp.x_train.keys())) + + # Merge x_train and y_train + for key in keys: + x_train_submatrices = [comp.x_train[key] + for comp in other_components + if key in comp.x_train.keys()] + y_train_submatrices = [comp.y_train[key] + for comp in other_components + if key in comp.y_train.keys()] + if key in self.x_train.keys(): + x_train_submatrices += [self.x_train[key]] + y_train_submatrices += [self.y_train[key]] + self.x_train[key] = np.vstack(x_train_submatrices) + self.y_train[key] = np.vstack(y_train_submatrices) + + # Merge trained ML predictors + for comp in other_components: + for key in comp.predictors.keys(): + if key not in self.predictors.keys(): + self.predictors[key] = comp.predictors[key] \ No newline at end of file diff --git a/miplearn/solvers.py b/miplearn/solvers.py index d410cab..ef547ae 100644 --- a/miplearn/solvers.py +++ b/miplearn/solvers.py @@ -327,21 +327,3 @@ class LearningSolver: return for component in self.components.values(): component.fit(training_instances) - - def save_state(self, filename): - with open(filename, "wb") as file: - pickle.dump({ - "version": 2, - "components": self.components, - }, file) - - def load_state(self, filename): - with open(filename, "rb") as file: - data = pickle.load(file) - assert data["version"] == 2 - for (component_name, component) in data["components"].items(): - if component_name not in self.components.keys(): - continue - else: - self.components[component_name].merge([component]) - diff --git a/miplearn/tests/test_benchmark.py b/miplearn/tests/test_benchmark.py index 094eee4..523ba61 100644 --- a/miplearn/tests/test_benchmark.py +++ b/miplearn/tests/test_benchmark.py @@ -18,8 +18,6 @@ def test_benchmark(): # Training phase... training_solver = LearningSolver() training_solver.parallel_solve(train_instances, n_jobs=10) - training_solver.fit() - training_solver.save_state("data.bin") # Test phase... test_solvers = { @@ -27,7 +25,7 @@ def test_benchmark(): "Strategy B": LearningSolver(), } benchmark = BenchmarkRunner(test_solvers) - benchmark.load_state("data.bin") + benchmark.fit(train_instances) benchmark.parallel_solve(test_instances, n_jobs=2, n_trials=2) assert benchmark.raw_results().values.shape == (12,13) diff --git a/miplearn/tests/test_solver.py b/miplearn/tests/test_solver.py index c8b662e..24d879b 100644 --- a/miplearn/tests/test_solver.py +++ b/miplearn/tests/test_solver.py @@ -41,29 +41,6 @@ def test_solver(): solver.fit() solver.solve(instance) - -# def test_solve_save_load_state(): -# instance = _get_instance() -# components_before = { -# "warm-start": WarmStartComponent(), -# } -# solver = LearningSolver(components=components_before) -# solver.solve(instance) -# solver.fit() -# solver.save_state("/tmp/knapsack_train.bin") -# prev_x_train_len = len(solver.components["warm-start"].x_train) -# prev_y_train_len = len(solver.components["warm-start"].y_train) - -# components_after = { -# "warm-start": WarmStartComponent(), -# } -# solver = LearningSolver(components=components_after) -# solver.load_state("/tmp/knapsack_train.bin") -# assert len(solver.components.keys()) == 1 -# assert len(solver.components["warm-start"].x_train) == prev_x_train_len -# assert len(solver.components["warm-start"].y_train) == prev_y_train_len - - def test_parallel_solve(): instances = [_get_instance() for _ in range(10)] solver = LearningSolver()