mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Disable progress bars when stdout it not a TTY
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@@ -25,7 +26,8 @@ class BranchPriorityExtractor(Extractor):
|
|||||||
result = {}
|
result = {}
|
||||||
for instance in tqdm(instances,
|
for instance in tqdm(instances,
|
||||||
desc="Extract (branch)",
|
desc="Extract (branch)",
|
||||||
disable=len(instances) < 5):
|
disable=len(instances) < 5 or (not sys.stdout.isatty()),
|
||||||
|
):
|
||||||
var_split = self.split_variables(instance)
|
var_split = self.split_variables(instance)
|
||||||
for (category, var_index_pairs) in var_split.items():
|
for (category, var_index_pairs) in var_split.items():
|
||||||
if category not in result:
|
if category not in result:
|
||||||
@@ -55,7 +57,10 @@ class BranchPriorityComponent(Component):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def fit(self, training_instances, n_jobs=1):
|
def fit(self, training_instances, n_jobs=1):
|
||||||
for instance in tqdm(training_instances, desc="Fit (branch)"):
|
for instance in tqdm(training_instances,
|
||||||
|
desc="Fit (branch)",
|
||||||
|
disable=not sys.stdout.isatty(),
|
||||||
|
):
|
||||||
if not hasattr(instance, "branch_priorities"):
|
if not hasattr(instance, "branch_priorities"):
|
||||||
instance.branch_priorities = self.compute_priorities(instance)
|
instance.branch_priorities = self.compute_priorities(instance)
|
||||||
x, y = self.x(training_instances), self.y(training_instances)
|
x, y = self.x(training_instances), self.y(training_instances)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from miplearn.classifiers.counting import CountingClassifier
|
from miplearn.classifiers.counting import CountingClassifier
|
||||||
@@ -52,7 +53,10 @@ class UserCutsComponent(Component):
|
|||||||
violation_to_instance_idx[v] = []
|
violation_to_instance_idx[v] = []
|
||||||
violation_to_instance_idx[v] += [idx]
|
violation_to_instance_idx[v] += [idx]
|
||||||
|
|
||||||
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (user cuts)"):
|
for (v, classifier) in tqdm(self.classifiers.items(),
|
||||||
|
desc="Fit (user cuts)",
|
||||||
|
disable=not sys.stdout.isatty(),
|
||||||
|
):
|
||||||
logger.debug("Training: %s" % (str(v)))
|
logger.debug("Training: %s" % (str(v)))
|
||||||
label = np.zeros(len(training_instances))
|
label = np.zeros(len(training_instances))
|
||||||
label[violation_to_instance_idx[v]] = 1.0
|
label[violation_to_instance_idx[v]] = 1.0
|
||||||
@@ -72,7 +76,10 @@ class UserCutsComponent(Component):
|
|||||||
all_violations = set()
|
all_violations = set()
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
all_violations |= set(instance.found_violated_user_cuts)
|
all_violations |= set(instance.found_violated_user_cuts)
|
||||||
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"):
|
for idx in tqdm(range(len(instances)),
|
||||||
|
desc="Evaluate (lazy)",
|
||||||
|
disable=not sys.stdout.isatty(),
|
||||||
|
):
|
||||||
instance = instances[idx]
|
instance = instances[idx]
|
||||||
condition_positive = set(instance.found_violated_user_cuts)
|
condition_positive = set(instance.found_violated_user_cuts)
|
||||||
condition_negative = all_violations - condition_positive
|
condition_negative = all_violations - condition_positive
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
import sys
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from miplearn.classifiers.counting import CountingClassifier
|
from miplearn.classifiers.counting import CountingClassifier
|
||||||
@@ -47,12 +48,17 @@ class LazyConstraintsComponent(Component):
|
|||||||
violation_to_instance_idx = {}
|
violation_to_instance_idx = {}
|
||||||
for (idx, instance) in enumerate(training_instances):
|
for (idx, instance) in enumerate(training_instances):
|
||||||
for v in instance.found_violated_lazy_constraints:
|
for v in instance.found_violated_lazy_constraints:
|
||||||
|
if isinstance(v, list):
|
||||||
|
v = tuple(v)
|
||||||
if v not in self.classifiers:
|
if v not in self.classifiers:
|
||||||
self.classifiers[v] = deepcopy(self.classifier_prototype)
|
self.classifiers[v] = deepcopy(self.classifier_prototype)
|
||||||
violation_to_instance_idx[v] = []
|
violation_to_instance_idx[v] = []
|
||||||
violation_to_instance_idx[v] += [idx]
|
violation_to_instance_idx[v] += [idx]
|
||||||
|
|
||||||
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (lazy)"):
|
for (v, classifier) in tqdm(self.classifiers.items(),
|
||||||
|
desc="Fit (lazy)",
|
||||||
|
disable=not sys.stdout.isatty(),
|
||||||
|
):
|
||||||
logger.debug("Training: %s" % (str(v)))
|
logger.debug("Training: %s" % (str(v)))
|
||||||
label = np.zeros(len(training_instances))
|
label = np.zeros(len(training_instances))
|
||||||
label[violation_to_instance_idx[v]] = 1.0
|
label[violation_to_instance_idx[v]] = 1.0
|
||||||
@@ -72,7 +78,10 @@ class LazyConstraintsComponent(Component):
|
|||||||
all_violations = set()
|
all_violations = set()
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
all_violations |= set(instance.found_violated_lazy_constraints)
|
all_violations |= set(instance.found_violated_lazy_constraints)
|
||||||
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"):
|
for idx in tqdm(range(len(instances)),
|
||||||
|
desc="Evaluate (lazy)",
|
||||||
|
disable=not sys.stdout.isatty(),
|
||||||
|
):
|
||||||
instance = instances[idx]
|
instance = instances[idx]
|
||||||
condition_positive = set(instance.found_violated_lazy_constraints)
|
condition_positive = set(instance.found_violated_lazy_constraints)
|
||||||
condition_negative = all_violations - condition_positive
|
condition_negative = all_violations - condition_positive
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
import sys
|
||||||
|
|
||||||
from .component import Component
|
from .component import Component
|
||||||
from ..classifiers.adaptive import AdaptiveClassifier
|
from ..classifiers.adaptive import AdaptiveClassifier
|
||||||
@@ -49,7 +50,10 @@ class PrimalSolutionComponent(Component):
|
|||||||
features = VariableFeaturesExtractor().extract(training_instances)
|
features = VariableFeaturesExtractor().extract(training_instances)
|
||||||
solutions = SolutionExtractor().extract(training_instances)
|
solutions = SolutionExtractor().extract(training_instances)
|
||||||
|
|
||||||
for category in tqdm(features.keys(), desc="Fit (primal)"):
|
for category in tqdm(features.keys(),
|
||||||
|
desc="Fit (primal)",
|
||||||
|
disable=not sys.stdout.isatty(),
|
||||||
|
):
|
||||||
x_train = features[category]
|
x_train = features[category]
|
||||||
for label in [0, 1]:
|
for label in [0, 1]:
|
||||||
y_train = solutions[category][:, label].astype(int)
|
y_train = solutions[category][:, label].astype(int)
|
||||||
@@ -104,7 +108,9 @@ class PrimalSolutionComponent(Component):
|
|||||||
ev = {"Fix zero": {},
|
ev = {"Fix zero": {},
|
||||||
"Fix one": {}}
|
"Fix one": {}}
|
||||||
for instance_idx in tqdm(range(len(instances)),
|
for instance_idx in tqdm(range(len(instances)),
|
||||||
desc="Evaluate (primal)"):
|
desc="Evaluate (primal)",
|
||||||
|
disable=not sys.stdout.isatty(),
|
||||||
|
):
|
||||||
instance = instances[instance_idx]
|
instance = instances[instance_idx]
|
||||||
solution_actual = instance.solution
|
solution_actual = instance.solution
|
||||||
solution_pred = self.predict(instance)
|
solution_pred = self.predict(instance)
|
||||||
|
|||||||
Reference in New Issue
Block a user