This commit is contained in:
2022-06-01 11:40:48 -05:00
parent 3fd252659e
commit 6cc253a903
15 changed files with 591 additions and 275 deletions

View File

@@ -14,6 +14,8 @@ from overrides import overrides
from miplearn.features.sample import Sample
from miplearn.instance.base import Instance
from miplearn.types import ConstraintName
from tqdm.auto import tqdm
from p_tqdm import p_umap
if TYPE_CHECKING:
from miplearn.solvers.learning import InternalSolver
@@ -155,13 +157,20 @@ def write_pickle_gz_multiple(objs: List[Any], dirname: str) -> None:
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
def save(objs: List[Any], dirname: str) -> List[str]:
def save(
objs: List[Any],
dirname: str,
progress: bool = False,
n_jobs: int = 1,
) -> List[str]:
"""
Saves the provided objects to gzipped pickled files. Files are named sequentially
as `dirname/00000.pkl.gz`, `dirname/00001.pkl.gz`, etc.
Parameters
----------
progress: bool
If True, show progress bar
objs: List[any]
List of files to save
dirname: str
@@ -171,11 +180,12 @@ def save(objs: List[Any], dirname: str) -> List[str]:
-------
List containing the relative paths of the saved files.
"""
filenames = []
for (i, obj) in enumerate(objs):
filename = f"{dirname}/{i:05d}.pkl.gz"
filenames.append(filename)
def _process(obj, filename):
write_pickle_gz(obj, filename)
filenames = [f"{dirname}/{i:05d}.pkl.gz" for i in range(len(objs))]
p_umap(_process, objs, filenames, num_cpus=n_jobs)
return filenames