mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Implement save function
This commit is contained in:
@@ -19,6 +19,7 @@ from .instance.picklegz import (
|
|||||||
write_pickle_gz,
|
write_pickle_gz,
|
||||||
read_pickle_gz,
|
read_pickle_gz,
|
||||||
write_pickle_gz_multiple,
|
write_pickle_gz_multiple,
|
||||||
|
save,
|
||||||
)
|
)
|
||||||
from .log import setup_logger
|
from .log import setup_logger
|
||||||
from .solvers.gurobi import GurobiSolver
|
from .solvers.gurobi import GurobiSolver
|
||||||
|
|||||||
@@ -153,3 +153,27 @@ def read_pickle_gz(filename: str) -> Any:
|
|||||||
def write_pickle_gz_multiple(objs: List[Any], dirname: str) -> None:
|
def write_pickle_gz_multiple(objs: List[Any], dirname: str) -> None:
|
||||||
for (i, obj) in enumerate(objs):
|
for (i, obj) in enumerate(objs):
|
||||||
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
|
write_pickle_gz(obj, f"{dirname}/{i:05d}.pkl.gz")
|
||||||
|
|
||||||
|
|
||||||
|
def save(objs: List[Any], dirname: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Saves the provided objects to gzipped pickled files. Files are named sequentially
|
||||||
|
as `dirname/00000.pkl.gz`, `dirname/00001.pkl.gz`, etc.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
objs: List[any]
|
||||||
|
List of files to save
|
||||||
|
dirname: str
|
||||||
|
Output directory
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List containing the relative paths of the saved files.
|
||||||
|
"""
|
||||||
|
filenames = []
|
||||||
|
for (i, obj) in enumerate(objs):
|
||||||
|
filename = f"{dirname}/{i:05d}.pkl.gz"
|
||||||
|
filenames.append(filename)
|
||||||
|
write_pickle_gz(obj, filename)
|
||||||
|
return filenames
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import cast, IO
|
||||||
|
|
||||||
from miplearn.instance.picklegz import write_pickle_gz, PickleGzInstance
|
from miplearn.instance.picklegz import write_pickle_gz, PickleGzInstance
|
||||||
from miplearn.solvers.gurobi import GurobiSolver
|
from miplearn.solvers.gurobi import GurobiSolver
|
||||||
|
from miplearn import save
|
||||||
|
from os.path import exists
|
||||||
|
import gzip
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
def test_usage() -> None:
|
def test_usage() -> None:
|
||||||
@@ -14,3 +20,14 @@ def test_usage() -> None:
|
|||||||
pickled = PickleGzInstance(file.name)
|
pickled = PickleGzInstance(file.name)
|
||||||
pickled.load()
|
pickled.load()
|
||||||
assert pickled.to_model() is not None
|
assert pickled.to_model() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_save() -> None:
|
||||||
|
objs = [1, "ABC", True]
|
||||||
|
with tempfile.TemporaryDirectory() as dirname:
|
||||||
|
filenames = save(objs, dirname)
|
||||||
|
assert len(filenames) == 3
|
||||||
|
for (idx, f) in enumerate(filenames):
|
||||||
|
assert exists(f)
|
||||||
|
with gzip.GzipFile(f, "rb") as file:
|
||||||
|
assert pickle.load(cast(IO[bytes], file)) == objs[idx]
|
||||||
|
|||||||
Reference in New Issue
Block a user