mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Branching: Make classifier configurable
This commit is contained in:
@@ -11,15 +11,22 @@ from tqdm.auto import tqdm
|
||||
from joblib import Parallel, delayed
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def _default_branch_priority_predictor():
|
||||
return KNeighborsRegressor(n_neighbors=1)
|
||||
|
||||
|
||||
class BranchPriorityComponent(Component):
|
||||
def __init__(self,
|
||||
node_limit=1_000,
|
||||
node_limit=10_000,
|
||||
predictor=_default_branch_priority_predictor,
|
||||
):
|
||||
self.pending_instances = []
|
||||
self.x_train = {}
|
||||
self.y_train = {}
|
||||
self.predictors = {}
|
||||
self.node_limit = node_limit
|
||||
self.predictor_factory = predictor
|
||||
|
||||
def before_solve(self, solver, instance, model):
|
||||
assert solver.is_persistent, "BranchPriorityComponent requires a persistent solver"
|
||||
@@ -92,7 +99,7 @@ class BranchPriorityComponent(Component):
|
||||
|
||||
return comp
|
||||
|
||||
|
||||
# Run strong branching on pending instances
|
||||
subcomponents = Parallel(n_jobs=n_jobs)(
|
||||
delayed(_process)(instance)
|
||||
for instance in tqdm(self.pending_instances, desc="Branch priority")
|
||||
@@ -104,10 +111,9 @@ class BranchPriorityComponent(Component):
|
||||
for category in self.x_train.keys():
|
||||
x_train = self.x_train[category]
|
||||
y_train = self.y_train[category]
|
||||
self.predictors[category] = KNeighborsRegressor(n_neighbors=1)
|
||||
self.predictors[category] = self.predictor_factory()
|
||||
self.predictors[category].fit(x_train, y_train)
|
||||
|
||||
|
||||
def _build_x(self, instance, var, index):
|
||||
instance_features = instance.get_instance_features()
|
||||
var_features = instance.get_variable_features(var, index)
|
||||
|
||||
Reference in New Issue
Block a user