mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Do not collect training data during benchmarks
This commit is contained in:
@@ -27,7 +27,8 @@ class BenchmarkRunner:
|
||||
for (name, solver) in self.solvers.items():
|
||||
results = solver.parallel_solve(instances,
|
||||
n_jobs=n_jobs,
|
||||
label=name)
|
||||
label=name,
|
||||
collect_training_data=False)
|
||||
for i in range(len(instances)):
|
||||
wallclock_time = None
|
||||
for key in ["Time", "Wall time", "Wallclock time"]:
|
||||
|
||||
@@ -55,9 +55,6 @@ class LearningSolver:
|
||||
self.internal_solver = self.internal_solver_factory()
|
||||
self.is_persistent = hasattr(self.internal_solver, "set_instance")
|
||||
|
||||
def _clear(self):
|
||||
self.internal_solver = None
|
||||
|
||||
def solve(self, instance, tee=False):
|
||||
model = instance.to_model()
|
||||
|
||||
@@ -80,13 +77,20 @@ class LearningSolver:
|
||||
|
||||
return solve_results
|
||||
|
||||
def parallel_solve(self, instances, n_jobs=4, label="Solve"):
|
||||
self._clear()
|
||||
def parallel_solve(self,
|
||||
instances,
|
||||
n_jobs=4,
|
||||
label="Solve",
|
||||
collect_training_data=True,
|
||||
):
|
||||
self.internal_solver = None
|
||||
|
||||
def _process(instance):
|
||||
solver = deepcopy(self)
|
||||
results = solver.solve(instance)
|
||||
solver._clear()
|
||||
solver.internal_solver = None
|
||||
if not collect_training_data:
|
||||
solver.components = {}
|
||||
return solver, results
|
||||
|
||||
solver_result_pairs = Parallel(n_jobs=n_jobs)(
|
||||
|
||||
Reference in New Issue
Block a user