diff --git a/miplearn/benchmark.py b/miplearn/benchmark.py index c741598..0146772 100644 --- a/miplearn/benchmark.py +++ b/miplearn/benchmark.py @@ -27,7 +27,8 @@ class BenchmarkRunner: for (name, solver) in self.solvers.items(): results = solver.parallel_solve(instances, n_jobs=n_jobs, - label=name) + label=name, + collect_training_data=False) for i in range(len(instances)): wallclock_time = None for key in ["Time", "Wall time", "Wallclock time"]: diff --git a/miplearn/solvers.py b/miplearn/solvers.py index c8b46b8..47144c8 100644 --- a/miplearn/solvers.py +++ b/miplearn/solvers.py @@ -55,9 +55,6 @@ class LearningSolver: self.internal_solver = self.internal_solver_factory() self.is_persistent = hasattr(self.internal_solver, "set_instance") - def _clear(self): - self.internal_solver = None - def solve(self, instance, tee=False): model = instance.to_model() @@ -80,13 +77,20 @@ class LearningSolver: return solve_results - def parallel_solve(self, instances, n_jobs=4, label="Solve"): - self._clear() + def parallel_solve(self, + instances, + n_jobs=4, + label="Solve", + collect_training_data=True, + ): + self.internal_solver = None def _process(instance): solver = deepcopy(self) results = solver.solve(instance) - solver._clear() + solver.internal_solver = None + if not collect_training_data: + solver.components = {} return solver, results solver_result_pairs = Parallel(n_jobs=n_jobs)(