Branching: Make classifier configurable

pull/1/head
Alinson S. Xavier 6 years ago
parent d7131e9f66
commit c3902ad61c

@ -11,15 +11,22 @@ from tqdm.auto import tqdm
from joblib import Parallel, delayed from joblib import Parallel, delayed
import multiprocessing import multiprocessing
def _default_branch_priority_predictor():
return KNeighborsRegressor(n_neighbors=1)
class BranchPriorityComponent(Component): class BranchPriorityComponent(Component):
def __init__(self, def __init__(self,
node_limit=1_000, node_limit=10_000,
predictor=_default_branch_priority_predictor,
): ):
self.pending_instances = [] self.pending_instances = []
self.x_train = {} self.x_train = {}
self.y_train = {} self.y_train = {}
self.predictors = {} self.predictors = {}
self.node_limit = node_limit self.node_limit = node_limit
self.predictor_factory = predictor
def before_solve(self, solver, instance, model): def before_solve(self, solver, instance, model):
assert solver.is_persistent, "BranchPriorityComponent requires a persistent solver" assert solver.is_persistent, "BranchPriorityComponent requires a persistent solver"
@ -92,7 +99,7 @@ class BranchPriorityComponent(Component):
return comp return comp
# Run strong branching on pending instances
subcomponents = Parallel(n_jobs=n_jobs)( subcomponents = Parallel(n_jobs=n_jobs)(
delayed(_process)(instance) delayed(_process)(instance)
for instance in tqdm(self.pending_instances, desc="Branch priority") for instance in tqdm(self.pending_instances, desc="Branch priority")
@ -104,9 +111,8 @@ class BranchPriorityComponent(Component):
for category in self.x_train.keys(): for category in self.x_train.keys():
x_train = self.x_train[category] x_train = self.x_train[category]
y_train = self.y_train[category] y_train = self.y_train[category]
self.predictors[category] = KNeighborsRegressor(n_neighbors=1) self.predictors[category] = self.predictor_factory()
self.predictors[category].fit(x_train, y_train) self.predictors[category].fit(x_train, y_train)
def _build_x(self, instance, var, index): def _build_x(self, instance, var, index):
instance_features = instance.get_instance_features() instance_features = instance.get_instance_features()

Loading…
Cancel
Save