BenchmarkRunner.fit: Only iterate through files twice

master
Alinson S. Xavier 4 years ago
parent aaef8b8fb3
commit 46a7d3fe26
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@@ -8,6 +8,7 @@ from typing import Dict, List
import pandas as pd import pandas as pd
from miplearn.components.component import Component
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.solvers.learning import LearningSolver from miplearn.solvers.learning import LearningSolver
@@ -106,10 +107,17 @@ class BenchmarkRunner:
---------- ----------
instances: List[Instance] instances: List[Instance]
List of training instances. List of training instances.
n_jobs: int
Number of parallel processes to use.
""" """
for (solver_name, solver) in self.solvers.items(): components: List[Component] = []
logger.debug(f"Fitting {solver_name}...") for solver in self.solvers.values():
solver.fit(instances, n_jobs=n_jobs) components += solver.components.values()
Component.fit_multiple(
components,
instances,
n_jobs=n_jobs,
)
def _silence_miplearn_logger(self) -> None: def _silence_miplearn_logger(self) -> None:
miplearn_logger = logging.getLogger("miplearn") miplearn_logger = logging.getLogger("miplearn")

@@ -183,18 +183,19 @@ class Component:
@staticmethod @staticmethod
def fit_multiple( def fit_multiple(
components: Dict[str, "Component"], components: List["Component"],
instances: List[Instance], instances: List[Instance],
n_jobs: int = 1, n_jobs: int = 1,
) -> None: ) -> None:
# Part I: Pre-fit
def _pre_sample_xy(instance: Instance) -> Dict: def _pre_sample_xy(instance: Instance) -> Dict:
pre_instance: Dict = {} pre_instance: Dict = {}
for (cname, comp) in components.items(): for (cidx, comp) in enumerate(components):
pre_instance[cname] = [] pre_instance[cidx] = []
instance.load() instance.load()
for sample in instance.samples: for sample in instance.samples:
for (cname, comp) in components.items(): for (cidx, comp) in enumerate(components):
pre_instance[cname].append(comp.pre_sample_xy(instance, sample)) pre_instance[cidx].append(comp.pre_sample_xy(instance, sample))
instance.free() instance.free()
return pre_instance return pre_instance
@@ -203,25 +204,25 @@ class Component:
else: else:
pre = p_umap(_pre_sample_xy, instances, num_cpus=n_jobs) pre = p_umap(_pre_sample_xy, instances, num_cpus=n_jobs)
pre_combined: Dict = {} pre_combined: Dict = {}
for (cname, comp) in components.items(): for (cidx, comp) in enumerate(components):
pre_combined[cname] = [] pre_combined[cidx] = []
for p in pre: for p in pre:
pre_combined[cname].extend(p[cname]) pre_combined[cidx].extend(p[cidx])
for (cidx, comp) in enumerate(components):
for (cname, comp) in components.items(): comp.pre_fit(pre_combined[cidx])
comp.pre_fit(pre_combined[cname])
# Part II: Fit
def _sample_xy(instance: Instance) -> Tuple[Dict, Dict]: def _sample_xy(instance: Instance) -> Tuple[Dict, Dict]:
x_instance: Dict = {} x_instance: Dict = {}
y_instance: Dict = {} y_instance: Dict = {}
for (cname, comp) in components.items(): for (cidx, comp) in enumerate(components):
x_instance[cname] = {} x_instance[cidx] = {}
y_instance[cname] = {} y_instance[cidx] = {}
instance.load() instance.load()
for sample in instance.samples: for sample in instance.samples:
for (cname, comp) in components.items(): for (cidx, comp) in enumerate(components):
x = x_instance[cname] x = x_instance[cidx]
y = y_instance[cname] y = y_instance[cidx]
x_sample, y_sample = comp.sample_xy(instance, sample) x_sample, y_sample = comp.sample_xy(instance, sample)
for cat in x_sample.keys(): for cat in x_sample.keys():
if cat not in x: if cat not in x:
@@ -236,17 +237,16 @@ class Component:
xy = [_sample_xy(instance) for instance in instances] xy = [_sample_xy(instance) for instance in instances]
else: else:
xy = p_umap(_sample_xy, instances) xy = p_umap(_sample_xy, instances)
for (cidx, comp) in enumerate(components):
for (cname, comp) in components.items():
x_comp: Dict = {} x_comp: Dict = {}
y_comp: Dict = {} y_comp: Dict = {}
for (x, y) in xy: for (x, y) in xy:
for cat in x[cname].keys(): for cat in x[cidx].keys():
if cat not in x_comp: if cat not in x_comp:
x_comp[cat] = [] x_comp[cat] = []
y_comp[cat] = [] y_comp[cat] = []
x_comp[cat].extend(x[cname][cat]) x_comp[cat].extend(x[cidx][cat])
y_comp[cat].extend(y[cname][cat]) y_comp[cat].extend(y[cidx][cat])
for cat in x_comp.keys(): for cat in x_comp.keys():
x_comp[cat] = np.array(x_comp[cat], dtype=np.float32) x_comp[cat] = np.array(x_comp[cat], dtype=np.float32)
y_comp[cat] = np.array(y_comp[cat]) y_comp[cat] = np.array(y_comp[cat])

@@ -413,7 +413,7 @@ class LearningSolver:
logger.warning("Empty list of training instances provided. Skipping.") logger.warning("Empty list of training instances provided. Skipping.")
return return
Component.fit_multiple( Component.fit_multiple(
self.components, list(self.components.values()),
training_instances, training_instances,
n_jobs=n_jobs, n_jobs=n_jobs,
) )

@@ -58,7 +58,7 @@ def test_learning_solver(
assert after_lp.lp_solve.lp_log is not None assert after_lp.lp_solve.lp_log is not None
assert len(after_lp.lp_solve.lp_log) > 100 assert len(after_lp.lp_solve.lp_log) > 100
solver.fit([instance]) solver.fit([instance], n_jobs=4)
solver.solve(instance) solver.solve(instance)
# Assert solver is picklable # Assert solver is picklable

@@ -28,7 +28,7 @@ def test_benchmark() -> None:
"Strategy B": LearningSolver(), "Strategy B": LearningSolver(),
} }
benchmark = BenchmarkRunner(test_solvers) benchmark = BenchmarkRunner(test_solvers)
benchmark.fit(train_instances) # type: ignore benchmark.fit(train_instances, n_jobs=n_jobs) # type: ignore
benchmark.parallel_solve( benchmark.parallel_solve(
test_instances, # type: ignore test_instances, # type: ignore
n_jobs=n_jobs, n_jobs=n_jobs,

Loading…
Cancel
Save