|
|
|
@ -11,15 +11,22 @@ from tqdm.auto import tqdm
|
|
|
|
|
from joblib import Parallel, delayed
|
|
|
|
|
import multiprocessing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _default_branch_priority_predictor():
|
|
|
|
|
return KNeighborsRegressor(n_neighbors=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BranchPriorityComponent(Component):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
node_limit=1_000,
|
|
|
|
|
node_limit=10_000,
|
|
|
|
|
predictor=_default_branch_priority_predictor,
|
|
|
|
|
):
|
|
|
|
|
self.pending_instances = []
|
|
|
|
|
self.x_train = {}
|
|
|
|
|
self.y_train = {}
|
|
|
|
|
self.predictors = {}
|
|
|
|
|
self.node_limit = node_limit
|
|
|
|
|
self.predictor_factory = predictor
|
|
|
|
|
|
|
|
|
|
def before_solve(self, solver, instance, model):
|
|
|
|
|
assert solver.is_persistent, "BranchPriorityComponent requires a persistent solver"
|
|
|
|
@ -92,7 +99,7 @@ class BranchPriorityComponent(Component):
|
|
|
|
|
|
|
|
|
|
return comp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run strong branching on pending instances
|
|
|
|
|
subcomponents = Parallel(n_jobs=n_jobs)(
|
|
|
|
|
delayed(_process)(instance)
|
|
|
|
|
for instance in tqdm(self.pending_instances, desc="Branch priority")
|
|
|
|
@ -104,9 +111,8 @@ class BranchPriorityComponent(Component):
|
|
|
|
|
for category in self.x_train.keys():
|
|
|
|
|
x_train = self.x_train[category]
|
|
|
|
|
y_train = self.y_train[category]
|
|
|
|
|
self.predictors[category] = KNeighborsRegressor(n_neighbors=1)
|
|
|
|
|
self.predictors[category] = self.predictor_factory()
|
|
|
|
|
self.predictors[category].fit(x_train, y_train)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_x(self, instance, var, index):
|
|
|
|
|
instance_features = instance.get_instance_features()
|
|
|
|
|