From c3902ad61c89f34e6d04f5f5d90ad1a7893082d3 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Fri, 21 Feb 2020 11:20:21 -0600 Subject: [PATCH] Branching: Make classifier configurable --- miplearn/components/branching.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/miplearn/components/branching.py b/miplearn/components/branching.py index be6c871..e716f8c 100644 --- a/miplearn/components/branching.py +++ b/miplearn/components/branching.py @@ -11,15 +11,22 @@ from tqdm.auto import tqdm from joblib import Parallel, delayed import multiprocessing + +def _default_branch_priority_predictor(): + return KNeighborsRegressor(n_neighbors=1) + + class BranchPriorityComponent(Component): def __init__(self, - node_limit=1_000, + node_limit=10_000, + predictor=_default_branch_priority_predictor, ): self.pending_instances = [] self.x_train = {} self.y_train = {} self.predictors = {} self.node_limit = node_limit + self.predictor_factory = predictor def before_solve(self, solver, instance, model): assert solver.is_persistent, "BranchPriorityComponent requires a persistent solver" @@ -92,7 +99,7 @@ class BranchPriorityComponent(Component): return comp - + # Run strong branching on pending instances subcomponents = Parallel(n_jobs=n_jobs)( delayed(_process)(instance) for instance in tqdm(self.pending_instances, desc="Branch priority") @@ -104,9 +111,8 @@ class BranchPriorityComponent(Component): for category in self.x_train.keys(): x_train = self.x_train[category] y_train = self.y_train[category] - self.predictors[category] = KNeighborsRegressor(n_neighbors=1) + self.predictors[category] = self.predictor_factory() self.predictors[category].fit(x_train, y_train) - def _build_x(self, instance, var, index): instance_features = instance.get_instance_features()