mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Branch: add progress bar
This commit is contained in:
@@ -66,8 +66,8 @@ def train():
|
|||||||
solver = LearningSolver(time_limit=train_time_limit,
|
solver = LearningSolver(time_limit=train_time_limit,
|
||||||
solver=internal_solver)
|
solver=internal_solver)
|
||||||
solver.add(BranchPriorityComponent())
|
solver.add(BranchPriorityComponent())
|
||||||
solver.parallel_solve(train_instances[:1], n_jobs=n_jobs)
|
solver.parallel_solve(train_instances, n_jobs=n_jobs)
|
||||||
solver.fit(train_instances[:1])
|
solver.fit(train_instances)
|
||||||
save(train_instances, "%s/train_instances.bin" % basepath)
|
save(train_instances, "%s/train_instances.bin" % basepath)
|
||||||
save(test_instances, "%s/test_instances.bin" % basepath)
|
save(test_instances, "%s/test_instances.bin" % basepath)
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class BranchPriorityComponent(Component):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def fit(self, training_instances, n_jobs=1):
|
def fit(self, training_instances, n_jobs=1):
|
||||||
for instance in training_instances:
|
for instance in tqdm(training_instances, desc="Fit (branch)"):
|
||||||
if not hasattr(instance, "branch_priorities"):
|
if not hasattr(instance, "branch_priorities"):
|
||||||
instance.branch_priorities = self.compute_priorities(instance)
|
instance.branch_priorities = self.compute_priorities(instance)
|
||||||
x, y = self.x(training_instances), self.y(training_instances)
|
x, y = self.x(training_instances), self.y(training_instances)
|
||||||
|
|||||||
Reference in New Issue
Block a user