Branch: add progress bar

This commit is contained in:
2020-04-24 12:42:20 -05:00
parent 80f2251877
commit e352e478d0
2 changed files with 3 additions and 3 deletions

View File

@@ -66,8 +66,8 @@ def train():
solver = LearningSolver(time_limit=train_time_limit, solver = LearningSolver(time_limit=train_time_limit,
solver=internal_solver) solver=internal_solver)
solver.add(BranchPriorityComponent()) solver.add(BranchPriorityComponent())
solver.parallel_solve(train_instances[:1], n_jobs=n_jobs) solver.parallel_solve(train_instances, n_jobs=n_jobs)
solver.fit(train_instances[:1]) solver.fit(train_instances)
save(train_instances, "%s/train_instances.bin" % basepath) save(train_instances, "%s/train_instances.bin" % basepath)
save(test_instances, "%s/test_instances.bin" % basepath) save(test_instances, "%s/test_instances.bin" % basepath)

View File

@@ -55,7 +55,7 @@ class BranchPriorityComponent(Component):
pass pass
def fit(self, training_instances, n_jobs=1): def fit(self, training_instances, n_jobs=1):
for instance in training_instances: for instance in tqdm(training_instances, desc="Fit (branch)"):
if not hasattr(instance, "branch_priorities"): if not hasattr(instance, "branch_priorities"):
instance.branch_priorities = self.compute_priorities(instance) instance.branch_priorities = self.compute_priorities(instance)
x, y = self.x(training_instances), self.y(training_instances) x, y = self.x(training_instances), self.y(training_instances)