Add n_jobs to BenchmarkRunner.fit

master
Alinson S. Xavier 5 years ago
parent 77b10b9609
commit e1f32b1798
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -98,7 +98,7 @@ class BenchmarkRunner:
os.makedirs(os.path.dirname(filename), exist_ok=True) os.makedirs(os.path.dirname(filename), exist_ok=True)
self.results.to_csv(filename) self.results.to_csv(filename)
def fit(self, instances: List[Instance]) -> None: def fit(self, instances: List[Instance], n_jobs: int = 1) -> None:
""" """
Trains all solvers with the provided training instances. Trains all solvers with the provided training instances.
@ -109,7 +109,7 @@ class BenchmarkRunner:
""" """
for (solver_name, solver) in self.solvers.items(): for (solver_name, solver) in self.solvers.items():
logger.debug(f"Fitting {solver_name}...") logger.debug(f"Fitting {solver_name}...")
solver.fit(instances) solver.fit(instances, n_jobs=n_jobs)
def _silence_miplearn_logger(self) -> None: def _silence_miplearn_logger(self) -> None:
miplearn_logger = logging.getLogger("miplearn") miplearn_logger = logging.getLogger("miplearn")

Loading…
Cancel
Save