mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Merge components more efficiently
This commit is contained in:
@@ -50,9 +50,10 @@ class BranchPriorityComponent(Component):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def merge(self, other):
|
def merge(self, other_components):
|
||||||
if other.priority is not None:
|
for comp in other_components:
|
||||||
self._merge(other.priority)
|
if comp.priority is not None:
|
||||||
|
self._merge(comp.priority)
|
||||||
|
|
||||||
|
|
||||||
def _merge(self, priority):
|
def _merge(self, priority):
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ def test_knapsack_fixed_weights_jitter():
|
|||||||
fix_w=True,
|
fix_w=True,
|
||||||
w_jitter=randint(low=0, high=1001),
|
w_jitter=randint(low=0, high=1001),
|
||||||
)
|
)
|
||||||
instances = gen.generate(1_000)
|
instances = gen.generate(5_000)
|
||||||
w = [instance.weights[0,0] for instance in instances]
|
w = [instance.weights[0,0] for instance in instances]
|
||||||
assert round(np.std(w), -1) == 290.
|
assert round(np.std(w), -1) == 290.
|
||||||
assert round(np.mean(w), -2) == 1500.
|
assert round(np.mean(w), -2) == 1500.
|
||||||
@@ -94,12 +94,14 @@ class LearningSolver:
|
|||||||
for instance in tqdm(instances, desc=label, ncols=80)
|
for instance in tqdm(instances, desc=label, ncols=80)
|
||||||
)
|
)
|
||||||
|
|
||||||
solvers = [p[0] for p in solver_result_pairs]
|
subsolvers = [p[0] for p in solver_result_pairs]
|
||||||
results = [p[1] for p in solver_result_pairs]
|
results = [p[1] for p in solver_result_pairs]
|
||||||
|
|
||||||
for (name, component) in self.components.items():
|
for (name, component) in self.components.items():
|
||||||
for subsolver in solvers:
|
subcomponents = [subsolver.components[name]
|
||||||
self.components[name].merge(subsolver.components[name])
|
for subsolver in subsolvers
|
||||||
|
if name in subsolver.components.keys()]
|
||||||
|
self.components[name].merge(subcomponents)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -122,4 +124,4 @@ class LearningSolver:
|
|||||||
if component_name not in self.components.keys():
|
if component_name not in self.components.keys():
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
self.components[component_name].merge(component)
|
self.components[component_name].merge([component])
|
||||||
|
|||||||
@@ -180,13 +180,20 @@ class WarmStartComponent(Component):
|
|||||||
self.predictors[category] = deepcopy(self.predictor_prototype)
|
self.predictors[category] = deepcopy(self.predictor_prototype)
|
||||||
self.predictors[category].fit(x_train, y_train)
|
self.predictors[category].fit(x_train, y_train)
|
||||||
|
|
||||||
def merge(self, other):
|
def merge(self, other_components):
|
||||||
for c in other.x_train.keys():
|
keys = set(self.x_train.keys())
|
||||||
if c not in self.x_train:
|
for comp in other_components:
|
||||||
self.x_train[c] = other.x_train[c]
|
keys = keys.union(set(comp.x_train.keys()))
|
||||||
self.y_train[c] = other.y_train[c]
|
|
||||||
else:
|
for key in keys:
|
||||||
self.x_train[c] = np.vstack([self.x_train[c], other.x_train[c]])
|
x_train_submatrices = [comp.x_train[key]
|
||||||
self.y_train[c] = np.vstack([self.y_train[c], other.y_train[c]])
|
for comp in other_components
|
||||||
if (c in other.predictors.keys()) and (c not in self.predictors.keys()):
|
if key in comp.x_train.keys()]
|
||||||
self.predictors[c] = other.predictors[c]
|
y_train_submatrices = [comp.y_train[key]
|
||||||
|
for comp in other_components
|
||||||
|
if key in comp.y_train.keys()]
|
||||||
|
if key in self.x_train.keys():
|
||||||
|
x_train_submatrices += [self.x_train[key]]
|
||||||
|
y_train_submatrices += [self.y_train[key]]
|
||||||
|
self.x_train[key] = np.vstack(x_train_submatrices)
|
||||||
|
self.y_train[key] = np.vstack(y_train_submatrices)
|
||||||
|
|||||||
Reference in New Issue
Block a user