mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make progress bars optional; other minor fixes
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
from typing import Any, List, TYPE_CHECKING, Tuple, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from p_tqdm import p_umap
|
||||
|
||||
from miplearn.features.sample import Sample
|
||||
@@ -186,6 +187,7 @@ class Component:
|
||||
components: List["Component"],
|
||||
instances: List[Instance],
|
||||
n_jobs: int = 1,
|
||||
progress: bool = False,
|
||||
) -> None:
|
||||
|
||||
# Part I: Pre-fit
|
||||
@@ -203,7 +205,13 @@ class Component:
|
||||
if n_jobs == 1:
|
||||
pre = [_pre_sample_xy(instance) for instance in instances]
|
||||
else:
|
||||
pre = p_umap(_pre_sample_xy, instances, num_cpus=n_jobs)
|
||||
pre = p_umap(
|
||||
_pre_sample_xy,
|
||||
instances,
|
||||
num_cpus=n_jobs,
|
||||
desc="pre-sample-xy",
|
||||
disable=not progress,
|
||||
)
|
||||
pre_combined: Dict = {}
|
||||
for (cidx, comp) in enumerate(components):
|
||||
pre_combined[cidx] = []
|
||||
@@ -237,8 +245,15 @@ class Component:
|
||||
if n_jobs == 1:
|
||||
xy = [_sample_xy(instance) for instance in instances]
|
||||
else:
|
||||
xy = p_umap(_sample_xy, instances)
|
||||
for (cidx, comp) in enumerate(components):
|
||||
xy = p_umap(_sample_xy, instances, desc="sample-xy", disable=not progress)
|
||||
|
||||
for (cidx, comp) in enumerate(
|
||||
tqdm(
|
||||
components,
|
||||
desc="fit",
|
||||
disable=not progress,
|
||||
)
|
||||
):
|
||||
x_comp: Dict = {}
|
||||
y_comp: Dict = {}
|
||||
for (x, y) in xy:
|
||||
|
||||
@@ -102,6 +102,8 @@ class DynamicConstraintsComponent(Component):
|
||||
assert pre is not None
|
||||
known_cids: Set = set()
|
||||
for cids in pre:
|
||||
if cids is None:
|
||||
continue
|
||||
known_cids |= set(list(cids))
|
||||
self.known_cids.clear()
|
||||
self.known_cids.extend(sorted(known_cids))
|
||||
|
||||
Reference in New Issue
Block a user