BenchmarkRunner.fit: Only iterate through files twice

master
Alinson S. Xavier 4 years ago
parent aaef8b8fb3
commit 46a7d3fe26
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@@ -8,6 +8,7 @@ from typing import Dict, List
import pandas as pd
from miplearn.components.component import Component
from miplearn.instance.base import Instance
from miplearn.solvers.learning import LearningSolver
@@ -106,10 +107,17 @@ class BenchmarkRunner:
----------
instances: List[Instance]
List of training instances.
n_jobs: int
Number of parallel processes to use.
"""
for (solver_name, solver) in self.solvers.items():
logger.debug(f"Fitting {solver_name}...")
solver.fit(instances, n_jobs=n_jobs)
components: List[Component] = []
for solver in self.solvers.values():
components += solver.components.values()
Component.fit_multiple(
components,
instances,
n_jobs=n_jobs,
)
def _silence_miplearn_logger(self) -> None:
miplearn_logger = logging.getLogger("miplearn")

@@ -183,18 +183,19 @@ class Component:
@staticmethod
def fit_multiple(
components: Dict[str, "Component"],
components: List["Component"],
instances: List[Instance],
n_jobs: int = 1,
) -> None:
# Part I: Pre-fit
def _pre_sample_xy(instance: Instance) -> Dict:
pre_instance: Dict = {}
for (cname, comp) in components.items():
pre_instance[cname] = []
for (cidx, comp) in enumerate(components):
pre_instance[cidx] = []
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
for (cidx, comp) in enumerate(components):
pre_instance[cidx].append(comp.pre_sample_xy(instance, sample))
instance.free()
return pre_instance
@@ -203,25 +204,25 @@ class Component:
else:
pre = p_umap(_pre_sample_xy, instances, num_cpus=n_jobs)
pre_combined: Dict = {}
for (cname, comp) in components.items():
pre_combined[cname] = []
for (cidx, comp) in enumerate(components):
pre_combined[cidx] = []
for p in pre:
pre_combined[cname].extend(p[cname])
for (cname, comp) in components.items():
comp.pre_fit(pre_combined[cname])
pre_combined[cidx].extend(p[cidx])
for (cidx, comp) in enumerate(components):
comp.pre_fit(pre_combined[cidx])
# Part II: Fit
def _sample_xy(instance: Instance) -> Tuple[Dict, Dict]:
x_instance: Dict = {}
y_instance: Dict = {}
for (cname, comp) in components.items():
x_instance[cname] = {}
y_instance[cname] = {}
for (cidx, comp) in enumerate(components):
x_instance[cidx] = {}
y_instance[cidx] = {}
instance.load()
for sample in instance.samples:
for (cname, comp) in components.items():
x = x_instance[cname]
y = y_instance[cname]
for (cidx, comp) in enumerate(components):
x = x_instance[cidx]
y = y_instance[cidx]
x_sample, y_sample = comp.sample_xy(instance, sample)
for cat in x_sample.keys():
if cat not in x:
@@ -236,17 +237,16 @@ class Component:
xy = [_sample_xy(instance) for instance in instances]
else:
xy = p_umap(_sample_xy, instances)
for (cname, comp) in components.items():
for (cidx, comp) in enumerate(components):
x_comp: Dict = {}
y_comp: Dict = {}
for (x, y) in xy:
for cat in x[cname].keys():
for cat in x[cidx].keys():
if cat not in x_comp:
x_comp[cat] = []
y_comp[cat] = []
x_comp[cat].extend(x[cname][cat])
y_comp[cat].extend(y[cname][cat])
x_comp[cat].extend(x[cidx][cat])
y_comp[cat].extend(y[cidx][cat])
for cat in x_comp.keys():
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
y_comp[cat] = np.array(y_comp[cat])

@@ -413,7 +413,7 @@ class LearningSolver:
logger.warning("Empty list of training instances provided. Skipping.")
return
Component.fit_multiple(
self.components,
list(self.components.values()),
training_instances,
n_jobs=n_jobs,
)

@@ -58,7 +58,7 @@ def test_learning_solver(
assert after_lp.lp_solve.lp_log is not None
assert len(after_lp.lp_solve.lp_log) > 100
solver.fit([instance])
solver.fit([instance], n_jobs=4)
solver.solve(instance)
# Assert solver is picklable

@@ -28,7 +28,7 @@ def test_benchmark() -> None:
"Strategy B": LearningSolver(),
}
benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances) # type: ignore
benchmark.fit(train_instances, n_jobs=n_jobs) # type: ignore
benchmark.parallel_solve(
test_instances, # type: ignore
n_jobs=n_jobs,

Loading…
Cancel
Save