Make progress bar labels consistent

pull/3/head
Alinson S. Xavier 6 years ago
parent 4f39dc02d6
commit 4e132a7677

@ -104,7 +104,7 @@ class BranchPriorityComponent(Component):
# Run strong branching on pending instances # Run strong branching on pending instances
subcomponents = Parallel(n_jobs=n_jobs)( subcomponents = Parallel(n_jobs=n_jobs)(
delayed(_process)(instance) delayed(_process)(instance)
for instance in tqdm(self.pending_instances, desc="Branch priority") for instance in tqdm(self.pending_instances, desc="Fit (branch)")
) )
self.merge(subcomponents) self.merge(subcomponents)
self.pending_instances.clear() self.pending_instances.clear()

@ -50,7 +50,7 @@ class PrimalSolutionComponent(Component):
features = VariableFeaturesExtractor().extract(training_instances) features = VariableFeaturesExtractor().extract(training_instances)
solutions = SolutionExtractor().extract(training_instances) solutions = SolutionExtractor().extract(training_instances)
for category in tqdm(features.keys(), desc="Fit (Primal)"): for category in tqdm(features.keys(), desc="Fit (primal)"):
x_train = features[category] x_train = features[category]
y_train = solutions[category] y_train = solutions[category]
for label in [0, 1]: for label in [0, 1]:
@ -116,7 +116,8 @@ class PrimalSolutionComponent(Component):
def evaluate(self, instances): def evaluate(self, instances):
ev = {"Fix zero": {}, ev = {"Fix zero": {},
"Fix one": {}} "Fix one": {}}
for instance_idx in tqdm(range(len(instances))): for instance_idx in tqdm(range(len(instances)),
desc="Evaluate (primal)"):
instance = instances[instance_idx] instance = instances[instance_idx]
solution_actual = instance.solution solution_actual = instance.solution
solution_pred = self.predict(instance) solution_pred = self.predict(instance)

@ -35,7 +35,7 @@ class VariableFeaturesExtractor(Extractor):
def extract(self, instances): def extract(self, instances):
result = {} result = {}
for instance in tqdm(instances, for instance in tqdm(instances,
desc="Extract var features", desc="Extract (vars)",
disable=len(instances) < 5): disable=len(instances) < 5):
instance_features = instance.get_instance_features() instance_features = instance.get_instance_features()
var_split = self.split_variables(instance) var_split = self.split_variables(instance)
@ -60,7 +60,7 @@ class SolutionExtractor(Extractor):
def extract(self, instances): def extract(self, instances):
result = {} result = {}
for instance in tqdm(instances, for instance in tqdm(instances,
desc="Extract solution", desc="Extract (solution)",
disable=len(instances) < 5): disable=len(instances) < 5):
var_split = self.split_variables(instance) var_split = self.split_variables(instance)
for (category, var_index_pairs) in var_split.items(): for (category, var_index_pairs) in var_split.items():

Loading…
Cancel
Save