Add n_jobs to BenchmarkRunner.fit

master
Alinson S. Xavier 5 years ago
parent 77b10b9609
commit e1f32b1798
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -98,7 +98,7 @@ class BenchmarkRunner:
os.makedirs(os.path.dirname(filename), exist_ok=True)
self.results.to_csv(filename)
def fit(self, instances: List[Instance]) -> None:
def fit(self, instances: List[Instance], n_jobs: int = 1) -> None:
"""
Trains all solvers with the provided training instances.
@ -109,7 +109,7 @@ class BenchmarkRunner:
"""
for (solver_name, solver) in self.solvers.items():
logger.debug(f"Fitting {solver_name}...")
solver.fit(instances)
solver.fit(instances, n_jobs=n_jobs)
def _silence_miplearn_logger(self) -> None:
miplearn_logger = logging.getLogger("miplearn")

Loading…
Cancel
Save