mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Disable progress bars when stdout it not a TTY
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
@@ -25,7 +26,8 @@ class BranchPriorityExtractor(Extractor):
|
||||
result = {}
|
||||
for instance in tqdm(instances,
|
||||
desc="Extract (branch)",
|
||||
disable=len(instances) < 5):
|
||||
disable=len(instances) < 5 or (not sys.stdout.isatty()),
|
||||
):
|
||||
var_split = self.split_variables(instance)
|
||||
for (category, var_index_pairs) in var_split.items():
|
||||
if category not in result:
|
||||
@@ -55,7 +57,10 @@ class BranchPriorityComponent(Component):
|
||||
pass
|
||||
|
||||
def fit(self, training_instances, n_jobs=1):
|
||||
for instance in tqdm(training_instances, desc="Fit (branch)"):
|
||||
for instance in tqdm(training_instances,
|
||||
desc="Fit (branch)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
if not hasattr(instance, "branch_priorities"):
|
||||
instance.branch_priorities = self.compute_priorities(instance)
|
||||
x, y = self.x(training_instances), self.y(training_instances)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
@@ -52,7 +53,10 @@ class UserCutsComponent(Component):
|
||||
violation_to_instance_idx[v] = []
|
||||
violation_to_instance_idx[v] += [idx]
|
||||
|
||||
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (user cuts)"):
|
||||
for (v, classifier) in tqdm(self.classifiers.items(),
|
||||
desc="Fit (user cuts)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
logger.debug("Training: %s" % (str(v)))
|
||||
label = np.zeros(len(training_instances))
|
||||
label[violation_to_instance_idx[v]] = 1.0
|
||||
@@ -72,7 +76,10 @@ class UserCutsComponent(Component):
|
||||
all_violations = set()
|
||||
for instance in instances:
|
||||
all_violations |= set(instance.found_violated_user_cuts)
|
||||
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"):
|
||||
for idx in tqdm(range(len(instances)),
|
||||
desc="Evaluate (lazy)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
instance = instances[idx]
|
||||
condition_positive = set(instance.found_violated_user_cuts)
|
||||
condition_negative = all_violations - condition_positive
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
@@ -47,12 +48,17 @@ class LazyConstraintsComponent(Component):
|
||||
violation_to_instance_idx = {}
|
||||
for (idx, instance) in enumerate(training_instances):
|
||||
for v in instance.found_violated_lazy_constraints:
|
||||
if isinstance(v, list):
|
||||
v = tuple(v)
|
||||
if v not in self.classifiers:
|
||||
self.classifiers[v] = deepcopy(self.classifier_prototype)
|
||||
violation_to_instance_idx[v] = []
|
||||
violation_to_instance_idx[v] += [idx]
|
||||
|
||||
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (lazy)"):
|
||||
for (v, classifier) in tqdm(self.classifiers.items(),
|
||||
desc="Fit (lazy)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
logger.debug("Training: %s" % (str(v)))
|
||||
label = np.zeros(len(training_instances))
|
||||
label[violation_to_instance_idx[v]] = 1.0
|
||||
@@ -72,7 +78,10 @@ class LazyConstraintsComponent(Component):
|
||||
all_violations = set()
|
||||
for instance in instances:
|
||||
all_violations |= set(instance.found_violated_lazy_constraints)
|
||||
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"):
|
||||
for idx in tqdm(range(len(instances)),
|
||||
desc="Evaluate (lazy)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
instance = instances[idx]
|
||||
condition_positive = set(instance.found_violated_lazy_constraints)
|
||||
condition_negative = all_violations - condition_positive
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from copy import deepcopy
|
||||
import sys
|
||||
|
||||
from .component import Component
|
||||
from ..classifiers.adaptive import AdaptiveClassifier
|
||||
@@ -49,7 +50,10 @@ class PrimalSolutionComponent(Component):
|
||||
features = VariableFeaturesExtractor().extract(training_instances)
|
||||
solutions = SolutionExtractor().extract(training_instances)
|
||||
|
||||
for category in tqdm(features.keys(), desc="Fit (primal)"):
|
||||
for category in tqdm(features.keys(),
|
||||
desc="Fit (primal)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
x_train = features[category]
|
||||
for label in [0, 1]:
|
||||
y_train = solutions[category][:, label].astype(int)
|
||||
@@ -104,7 +108,9 @@ class PrimalSolutionComponent(Component):
|
||||
ev = {"Fix zero": {},
|
||||
"Fix one": {}}
|
||||
for instance_idx in tqdm(range(len(instances)),
|
||||
desc="Evaluate (primal)"):
|
||||
desc="Evaluate (primal)",
|
||||
disable=not sys.stdout.isatty(),
|
||||
):
|
||||
instance = instances[instance_idx]
|
||||
solution_actual = instance.solution
|
||||
solution_pred = self.predict(instance)
|
||||
|
||||
Reference in New Issue
Block a user