Disable progress bars when stdout it not a TTY

pull/3/head
Alinson S. Xavier 5 years ago
parent 0658bf2722
commit 302372847e

@ -4,6 +4,7 @@
import logging
import os
import sys
import subprocess
import tempfile
from copy import deepcopy
@ -25,7 +26,8 @@ class BranchPriorityExtractor(Extractor):
result = {}
for instance in tqdm(instances,
desc="Extract (branch)",
disable=len(instances) < 5):
disable=len(instances) < 5 or (not sys.stdout.isatty()),
):
var_split = self.split_variables(instance)
for (category, var_index_pairs) in var_split.items():
if category not in result:
@ -55,7 +57,10 @@ class BranchPriorityComponent(Component):
pass
def fit(self, training_instances, n_jobs=1):
for instance in tqdm(training_instances, desc="Fit (branch)"):
for instance in tqdm(training_instances,
desc="Fit (branch)",
disable=not sys.stdout.isatty(),
):
if not hasattr(instance, "branch_priorities"):
instance.branch_priorities = self.compute_priorities(instance)
x, y = self.x(training_instances), self.y(training_instances)

@ -2,6 +2,7 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import sys
from copy import deepcopy
from miplearn.classifiers.counting import CountingClassifier
@ -52,7 +53,10 @@ class UserCutsComponent(Component):
violation_to_instance_idx[v] = []
violation_to_instance_idx[v] += [idx]
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (user cuts)"):
for (v, classifier) in tqdm(self.classifiers.items(),
desc="Fit (user cuts)",
disable=not sys.stdout.isatty(),
):
logger.debug("Training: %s" % (str(v)))
label = np.zeros(len(training_instances))
label[violation_to_instance_idx[v]] = 1.0
@ -72,7 +76,10 @@ class UserCutsComponent(Component):
all_violations = set()
for instance in instances:
all_violations |= set(instance.found_violated_user_cuts)
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"):
for idx in tqdm(range(len(instances)),
desc="Evaluate (lazy)",
disable=not sys.stdout.isatty(),
):
instance = instances[idx]
condition_positive = set(instance.found_violated_user_cuts)
condition_negative = all_violations - condition_positive

@ -2,6 +2,7 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import sys
from copy import deepcopy
from miplearn.classifiers.counting import CountingClassifier
@ -47,12 +48,17 @@ class LazyConstraintsComponent(Component):
violation_to_instance_idx = {}
for (idx, instance) in enumerate(training_instances):
for v in instance.found_violated_lazy_constraints:
if isinstance(v, list):
v = tuple(v)
if v not in self.classifiers:
self.classifiers[v] = deepcopy(self.classifier_prototype)
violation_to_instance_idx[v] = []
violation_to_instance_idx[v] += [idx]
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (lazy)"):
for (v, classifier) in tqdm(self.classifiers.items(),
desc="Fit (lazy)",
disable=not sys.stdout.isatty(),
):
logger.debug("Training: %s" % (str(v)))
label = np.zeros(len(training_instances))
label[violation_to_instance_idx[v]] = 1.0
@ -72,7 +78,10 @@ class LazyConstraintsComponent(Component):
all_violations = set()
for instance in instances:
all_violations |= set(instance.found_violated_lazy_constraints)
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"):
for idx in tqdm(range(len(instances)),
desc="Evaluate (lazy)",
disable=not sys.stdout.isatty(),
):
instance = instances[idx]
condition_positive = set(instance.found_violated_lazy_constraints)
condition_negative = all_violations - condition_positive

@ -3,6 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from copy import deepcopy
import sys
from .component import Component
from ..classifiers.adaptive import AdaptiveClassifier
@ -49,7 +50,10 @@ class PrimalSolutionComponent(Component):
features = VariableFeaturesExtractor().extract(training_instances)
solutions = SolutionExtractor().extract(training_instances)
for category in tqdm(features.keys(), desc="Fit (primal)"):
for category in tqdm(features.keys(),
desc="Fit (primal)",
disable=not sys.stdout.isatty(),
):
x_train = features[category]
for label in [0, 1]:
y_train = solutions[category][:, label].astype(int)
@ -104,7 +108,9 @@ class PrimalSolutionComponent(Component):
ev = {"Fix zero": {},
"Fix one": {}}
for instance_idx in tqdm(range(len(instances)),
desc="Evaluate (primal)"):
desc="Evaluate (primal)",
disable=not sys.stdout.isatty(),
):
instance = instances[instance_idx]
solution_actual = instance.solution
solution_pred = self.predict(instance)

Loading…
Cancel
Save