diff --git a/miplearn/__init__.py b/miplearn/__init__.py index 692281f..071ff3c 100644 --- a/miplearn/__init__.py +++ b/miplearn/__init__.py @@ -19,6 +19,7 @@ from .instance.picklegz import ( write_pickle_gz, read_pickle_gz, write_pickle_gz_multiple, + save, ) from .log import setup_logger from .solvers.gurobi import GurobiSolver diff --git a/miplearn/instance/picklegz.py b/miplearn/instance/picklegz.py index bdceae7..1a9db99 100644 --- a/miplearn/instance/picklegz.py +++ b/miplearn/instance/picklegz.py @@ -153,3 +153,27 @@ def read_pickle_gz(filename: str) -> Any: def write_pickle_gz_multiple(objs: List[Any], dirname: str) -> None: for (i, obj) in enumerate(objs): write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz") + + +def save(objs: List[Any], dirname: str) -> List[str]: + """ + Saves the provided objects to gzipped pickled files. Files are named sequentially + as `dirname/00000.pkl.gz`, `dirname/00001.pkl.gz`, etc. + + Parameters + ---------- + objs: List[any] + List of files to save + dirname: str + Output directory + + Returns + ------- + List containing the relative paths of the saved files. + """ + filenames = [] + for (i, obj) in enumerate(objs): + filename = f"{dirname}/{i:05d}.pkl.gz" + filenames.append(filename) + write_pickle_gz(obj, filename) + return filenames diff --git a/tests/instance/test_picklegz.py b/tests/instance/test_picklegz.py index ebdb017..e7b14e3 100644 --- a/tests/instance/test_picklegz.py +++ b/tests/instance/test_picklegz.py @@ -1,10 +1,16 @@ # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. + import tempfile +from typing import cast, IO from miplearn.instance.picklegz import write_pickle_gz, PickleGzInstance from miplearn.solvers.gurobi import GurobiSolver +from miplearn import save +from os.path import exists +import gzip +import pickle def test_usage() -> None: @@ -14,3 +20,14 @@ def test_usage() -> None: pickled = PickleGzInstance(file.name) pickled.load() assert pickled.to_model() is not None + + +def test_save() -> None: + objs = [1, "ABC", True] + with tempfile.TemporaryDirectory() as dirname: + filenames = save(objs, dirname) + assert len(filenames) == 3 + for (idx, f) in enumerate(filenames): + assert exists(f) + with gzip.GzipFile(f, "rb") as file: + assert pickle.load(cast(IO[bytes], file)) == objs[idx]