From 2679522b76e2a5159ce52feaa43b5be1c4cf6dc2 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Fri, 31 Jan 2020 11:06:07 -0600 Subject: [PATCH] Merge components more efficiently --- miplearn/branching.py | 7 +++--- miplearn/problems/tests/test_knapsack.py | 2 +- miplearn/solvers.py | 10 +++++---- miplearn/warmstart.py | 27 +++++++++++++++--------- 4 files changed, 28 insertions(+), 18 deletions(-) diff --git a/miplearn/branching.py b/miplearn/branching.py index 0654953..3163a80 100644 --- a/miplearn/branching.py +++ b/miplearn/branching.py @@ -50,9 +50,10 @@ class BranchPriorityComponent(Component): pass - def merge(self, other): - if other.priority is not None: - self._merge(other.priority) + def merge(self, other_components): + for comp in other_components: + if comp.priority is not None: + self._merge(comp.priority) def _merge(self, priority): diff --git a/miplearn/problems/tests/test_knapsack.py b/miplearn/problems/tests/test_knapsack.py index 3cbd960..31d31da 100644 --- a/miplearn/problems/tests/test_knapsack.py +++ b/miplearn/problems/tests/test_knapsack.py @@ -66,7 +66,7 @@ def test_knapsack_fixed_weights_jitter(): fix_w=True, w_jitter=randint(low=0, high=1001), ) - instances = gen.generate(1_000) + instances = gen.generate(5_000) w = [instance.weights[0,0] for instance in instances] assert round(np.std(w), -1) == 290. assert round(np.mean(w), -2) == 1500. \ No newline at end of file diff --git a/miplearn/solvers.py b/miplearn/solvers.py index 69b7020..c8b46b8 100644 --- a/miplearn/solvers.py +++ b/miplearn/solvers.py @@ -94,12 +94,14 @@ class LearningSolver: for instance in tqdm(instances, desc=label, ncols=80) ) - solvers = [p[0] for p in solver_result_pairs] + subsolvers = [p[0] for p in solver_result_pairs] results = [p[1] for p in solver_result_pairs] for (name, component) in self.components.items(): - for subsolver in solvers: - self.components[name].merge(subsolver.components[name]) + subcomponents = [subsolver.components[name] + for subsolver in subsolvers + if name in subsolver.components.keys()] + self.components[name].merge(subcomponents) return results @@ -122,4 +124,4 @@ class LearningSolver: if component_name not in self.components.keys(): continue else: - self.components[component_name].merge(component) + self.components[component_name].merge([component]) diff --git a/miplearn/warmstart.py b/miplearn/warmstart.py index 9e003c8..9a75a7e 100644 --- a/miplearn/warmstart.py +++ b/miplearn/warmstart.py @@ -180,13 +180,20 @@ class WarmStartComponent(Component): self.predictors[category] = deepcopy(self.predictor_prototype) self.predictors[category].fit(x_train, y_train) - def merge(self, other): - for c in other.x_train.keys(): - if c not in self.x_train: - self.x_train[c] = other.x_train[c] - self.y_train[c] = other.y_train[c] - else: - self.x_train[c] = np.vstack([self.x_train[c], other.x_train[c]]) - self.y_train[c] = np.vstack([self.y_train[c], other.y_train[c]]) - if (c in other.predictors.keys()) and (c not in self.predictors.keys()): - self.predictors[c] = other.predictors[c] \ No newline at end of file + def merge(self, other_components): + keys = set(self.x_train.keys()) + for comp in other_components: + keys = keys.union(set(comp.x_train.keys())) + + for key in keys: + x_train_submatrices = [comp.x_train[key] + for comp in other_components + if key in comp.x_train.keys()] + y_train_submatrices = [comp.y_train[key] + for comp in other_components + if key in comp.y_train.keys()] + if key in self.x_train.keys(): + x_train_submatrices += [self.x_train[key]] + y_train_submatrices += [self.y_train[key]] + self.x_train[key] = np.vstack(x_train_submatrices) + self.y_train[key] = np.vstack(y_train_submatrices)