diff --git a/miplearn/benchmark.py b/miplearn/benchmark.py
index da59a69..9ba6995 100644
--- a/miplearn/benchmark.py
+++ b/miplearn/benchmark.py
@@ -8,6 +8,7 @@ from typing import Dict, List
 
 import pandas as pd
 
+from miplearn.components.component import Component
 from miplearn.instance.base import Instance
 from miplearn.solvers.learning import LearningSolver
 
@@ -106,10 +107,17 @@ class BenchmarkRunner:
         ----------
         instances: List[Instance]
             List of training instances.
+        n_jobs: int
+            Number of parallel processes to use.
         """
-        for (solver_name, solver) in self.solvers.items():
-            logger.debug(f"Fitting {solver_name}...")
-            solver.fit(instances, n_jobs=n_jobs)
+        components: List[Component] = []
+        for solver in self.solvers.values():
+            components += solver.components.values()
+        Component.fit_multiple(
+            components,
+            instances,
+            n_jobs=n_jobs,
+        )
 
     def _silence_miplearn_logger(self) -> None:
         miplearn_logger = logging.getLogger("miplearn")
diff --git a/miplearn/components/component.py b/miplearn/components/component.py
index 75e9017..c02628a 100644
--- a/miplearn/components/component.py
+++ b/miplearn/components/component.py
@@ -183,18 +183,19 @@ class Component:
 
     @staticmethod
     def fit_multiple(
-        components: Dict[str, "Component"],
+        components: List["Component"],
         instances: List[Instance],
         n_jobs: int = 1,
     ) -> None:
+        # Part I: Pre-fit
        def _pre_sample_xy(instance: Instance) -> Dict:
             pre_instance: Dict = {}
-            for (cname, comp) in components.items():
-                pre_instance[cname] = []
+            for (cidx, comp) in enumerate(components):
+                pre_instance[cidx] = []
             instance.load()
             for sample in instance.samples:
-                for (cname, comp) in components.items():
-                    pre_instance[cname].append(comp.pre_sample_xy(instance, sample))
+                for (cidx, comp) in enumerate(components):
+                    pre_instance[cidx].append(comp.pre_sample_xy(instance, sample))
             instance.free()
             return pre_instance
 
@@ -203,25 +204,25 @@ class Component:
         else:
             pre = p_umap(_pre_sample_xy, instances, num_cpus=n_jobs)
         pre_combined: Dict = {}
-        for (cname, comp) in components.items():
-            pre_combined[cname] = []
+        for (cidx, comp) in enumerate(components):
+            pre_combined[cidx] = []
             for p in pre:
-                pre_combined[cname].extend(p[cname])
-
-        for (cname, comp) in components.items():
-            comp.pre_fit(pre_combined[cname])
+                pre_combined[cidx].extend(p[cidx])
+        for (cidx, comp) in enumerate(components):
+            comp.pre_fit(pre_combined[cidx])
 
+        # Part II: Fit
         def _sample_xy(instance: Instance) -> Tuple[Dict, Dict]:
             x_instance: Dict = {}
             y_instance: Dict = {}
-            for (cname, comp) in components.items():
-                x_instance[cname] = {}
-                y_instance[cname] = {}
+            for (cidx, comp) in enumerate(components):
+                x_instance[cidx] = {}
+                y_instance[cidx] = {}
             instance.load()
             for sample in instance.samples:
-                for (cname, comp) in components.items():
-                    x = x_instance[cname]
-                    y = y_instance[cname]
+                for (cidx, comp) in enumerate(components):
+                    x = x_instance[cidx]
+                    y = y_instance[cidx]
                     x_sample, y_sample = comp.sample_xy(instance, sample)
                     for cat in x_sample.keys():
                         if cat not in x:
@@ -236,17 +237,16 @@ class Component:
             xy = [_sample_xy(instance) for instance in instances]
         else:
             xy = p_umap(_sample_xy, instances)
-
-        for (cname, comp) in components.items():
+        for (cidx, comp) in enumerate(components):
             x_comp: Dict = {}
             y_comp: Dict = {}
             for (x, y) in xy:
-                for cat in x[cname].keys():
+                for cat in x[cidx].keys():
                     if cat not in x_comp:
                         x_comp[cat] = []
                         y_comp[cat] = []
-                    x_comp[cat].extend(x[cname][cat])
-                    y_comp[cat].extend(y[cname][cat])
+                    x_comp[cat].extend(x[cidx][cat])
+                    y_comp[cat].extend(y[cidx][cat])
             for cat in x_comp.keys():
                 x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
                 y_comp[cat] = np.array(y_comp[cat])
diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py
index 0f0926c..5f8183a 100644
--- a/miplearn/solvers/learning.py
+++ b/miplearn/solvers/learning.py
@@ -413,7 +413,7 @@ class LearningSolver:
             logger.warning("Empty list of training instances provided. Skipping.")
             return
         Component.fit_multiple(
-            self.components,
+            list(self.components.values()),
             training_instances,
             n_jobs=n_jobs,
         )
diff --git a/tests/solvers/test_learning_solver.py b/tests/solvers/test_learning_solver.py
index 85b33a8..8c51a17 100644
--- a/tests/solvers/test_learning_solver.py
+++ b/tests/solvers/test_learning_solver.py
@@ -58,7 +58,7 @@ def test_learning_solver(
         assert after_lp.lp_solve.lp_log is not None
         assert len(after_lp.lp_solve.lp_log) > 100
 
-        solver.fit([instance])
+        solver.fit([instance], n_jobs=4)
         solver.solve(instance)
 
         # Assert solver is picklable
diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py
index 1183aa4..ad72bf4 100644
--- a/tests/test_benchmark.py
+++ b/tests/test_benchmark.py
@@ -28,7 +28,7 @@ def test_benchmark() -> None:
         "Strategy B": LearningSolver(),
     }
     benchmark = BenchmarkRunner(test_solvers)
-    benchmark.fit(train_instances)  # type: ignore
+    benchmark.fit(train_instances, n_jobs=n_jobs)  # type: ignore
     benchmark.parallel_solve(
         test_instances,  # type: ignore
         n_jobs=n_jobs,
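
A minimal usage sketch of the reworked interface, assuming a list of training instances named train_instances built elsewhere (the variable name is illustrative, not part of this patch). It shows that Component.fit_multiple now takes a flat list of components instead of Dict[str, "Component"], and that BenchmarkRunner.fit accepts n_jobs.

    from miplearn.benchmark import BenchmarkRunner
    from miplearn.components.component import Component
    from miplearn.solvers.learning import LearningSolver

    # Fit the components of a single solver; they are passed positionally, not by name.
    solver = LearningSolver()
    Component.fit_multiple(
        list(solver.components.values()),  # List["Component"] instead of Dict[str, "Component"]
        train_instances,                   # assumed: List[Instance] defined elsewhere
        n_jobs=4,                          # sample extraction runs across 4 processes
    )

    # Or let BenchmarkRunner pool the components of all its solvers and fit them in one pass.
    benchmark = BenchmarkRunner({"baseline": LearningSolver(), "ml": LearningSolver()})
    benchmark.fit(train_instances, n_jobs=4)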