Disable progress bars when stdout it not a TTY

pull/3/head
Alinson S. Xavier 5 years ago
parent 0658bf2722
commit 302372847e

@ -4,6 +4,7 @@
import logging import logging
import os import os
import sys
import subprocess import subprocess
import tempfile import tempfile
from copy import deepcopy from copy import deepcopy
@ -25,7 +26,8 @@ class BranchPriorityExtractor(Extractor):
result = {} result = {}
for instance in tqdm(instances, for instance in tqdm(instances,
desc="Extract (branch)", desc="Extract (branch)",
disable=len(instances) < 5): disable=len(instances) < 5 or (not sys.stdout.isatty()),
):
var_split = self.split_variables(instance) var_split = self.split_variables(instance)
for (category, var_index_pairs) in var_split.items(): for (category, var_index_pairs) in var_split.items():
if category not in result: if category not in result:
@ -55,7 +57,10 @@ class BranchPriorityComponent(Component):
pass pass
def fit(self, training_instances, n_jobs=1): def fit(self, training_instances, n_jobs=1):
for instance in tqdm(training_instances, desc="Fit (branch)"): for instance in tqdm(training_instances,
desc="Fit (branch)",
disable=not sys.stdout.isatty(),
):
if not hasattr(instance, "branch_priorities"): if not hasattr(instance, "branch_priorities"):
instance.branch_priorities = self.compute_priorities(instance) instance.branch_priorities = self.compute_priorities(instance)
x, y = self.x(training_instances), self.y(training_instances) x, y = self.x(training_instances), self.y(training_instances)

@ -2,6 +2,7 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import sys
from copy import deepcopy from copy import deepcopy
from miplearn.classifiers.counting import CountingClassifier from miplearn.classifiers.counting import CountingClassifier
@ -52,7 +53,10 @@ class UserCutsComponent(Component):
violation_to_instance_idx[v] = [] violation_to_instance_idx[v] = []
violation_to_instance_idx[v] += [idx] violation_to_instance_idx[v] += [idx]
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (user cuts)"): for (v, classifier) in tqdm(self.classifiers.items(),
desc="Fit (user cuts)",
disable=not sys.stdout.isatty(),
):
logger.debug("Training: %s" % (str(v))) logger.debug("Training: %s" % (str(v)))
label = np.zeros(len(training_instances)) label = np.zeros(len(training_instances))
label[violation_to_instance_idx[v]] = 1.0 label[violation_to_instance_idx[v]] = 1.0
@ -72,7 +76,10 @@ class UserCutsComponent(Component):
all_violations = set() all_violations = set()
for instance in instances: for instance in instances:
all_violations |= set(instance.found_violated_user_cuts) all_violations |= set(instance.found_violated_user_cuts)
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"): for idx in tqdm(range(len(instances)),
desc="Evaluate (lazy)",
disable=not sys.stdout.isatty(),
):
instance = instances[idx] instance = instances[idx]
condition_positive = set(instance.found_violated_user_cuts) condition_positive = set(instance.found_violated_user_cuts)
condition_negative = all_violations - condition_positive condition_negative = all_violations - condition_positive

@ -2,6 +2,7 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import sys
from copy import deepcopy from copy import deepcopy
from miplearn.classifiers.counting import CountingClassifier from miplearn.classifiers.counting import CountingClassifier
@ -47,12 +48,17 @@ class LazyConstraintsComponent(Component):
violation_to_instance_idx = {} violation_to_instance_idx = {}
for (idx, instance) in enumerate(training_instances): for (idx, instance) in enumerate(training_instances):
for v in instance.found_violated_lazy_constraints: for v in instance.found_violated_lazy_constraints:
if isinstance(v, list):
v = tuple(v)
if v not in self.classifiers: if v not in self.classifiers:
self.classifiers[v] = deepcopy(self.classifier_prototype) self.classifiers[v] = deepcopy(self.classifier_prototype)
violation_to_instance_idx[v] = [] violation_to_instance_idx[v] = []
violation_to_instance_idx[v] += [idx] violation_to_instance_idx[v] += [idx]
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (lazy)"): for (v, classifier) in tqdm(self.classifiers.items(),
desc="Fit (lazy)",
disable=not sys.stdout.isatty(),
):
logger.debug("Training: %s" % (str(v))) logger.debug("Training: %s" % (str(v)))
label = np.zeros(len(training_instances)) label = np.zeros(len(training_instances))
label[violation_to_instance_idx[v]] = 1.0 label[violation_to_instance_idx[v]] = 1.0
@ -72,7 +78,10 @@ class LazyConstraintsComponent(Component):
all_violations = set() all_violations = set()
for instance in instances: for instance in instances:
all_violations |= set(instance.found_violated_lazy_constraints) all_violations |= set(instance.found_violated_lazy_constraints)
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"): for idx in tqdm(range(len(instances)),
desc="Evaluate (lazy)",
disable=not sys.stdout.isatty(),
):
instance = instances[idx] instance = instances[idx]
condition_positive = set(instance.found_violated_lazy_constraints) condition_positive = set(instance.found_violated_lazy_constraints)
condition_negative = all_violations - condition_positive condition_negative = all_violations - condition_positive

@ -3,6 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from copy import deepcopy from copy import deepcopy
import sys
from .component import Component from .component import Component
from ..classifiers.adaptive import AdaptiveClassifier from ..classifiers.adaptive import AdaptiveClassifier
@ -49,7 +50,10 @@ class PrimalSolutionComponent(Component):
features = VariableFeaturesExtractor().extract(training_instances) features = VariableFeaturesExtractor().extract(training_instances)
solutions = SolutionExtractor().extract(training_instances) solutions = SolutionExtractor().extract(training_instances)
for category in tqdm(features.keys(), desc="Fit (primal)"): for category in tqdm(features.keys(),
desc="Fit (primal)",
disable=not sys.stdout.isatty(),
):
x_train = features[category] x_train = features[category]
for label in [0, 1]: for label in [0, 1]:
y_train = solutions[category][:, label].astype(int) y_train = solutions[category][:, label].astype(int)
@ -104,7 +108,9 @@ class PrimalSolutionComponent(Component):
ev = {"Fix zero": {}, ev = {"Fix zero": {},
"Fix one": {}} "Fix one": {}}
for instance_idx in tqdm(range(len(instances)), for instance_idx in tqdm(range(len(instances)),
desc="Evaluate (primal)"): desc="Evaluate (primal)",
disable=not sys.stdout.isatty(),
):
instance = instances[instance_idx] instance = instances[instance_idx]
solution_actual = instance.solution solution_actual = instance.solution
solution_pred = self.predict(instance) solution_pred = self.predict(instance)

Loading…
Cancel
Save