mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-09 02:48:52 -06:00
MIPLearn v0.3
This commit is contained in:
92
miplearn/io.py
Normal file
92
miplearn/io.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from gzip import GzipFile
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from typing import IO, Any, Callable, List, cast, TextIO
|
||||
|
||||
from .parallel import p_umap
|
||||
import shutil
|
||||
|
||||
|
||||
class _RedirectOutput:
|
||||
def __init__(self, streams: List[Any]) -> None:
|
||||
self.streams = streams
|
||||
|
||||
def write(self, data: Any) -> None:
|
||||
for stream in self.streams:
|
||||
stream.write(data)
|
||||
|
||||
def flush(self) -> None:
|
||||
for stream in self.streams:
|
||||
stream.flush()
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
self._original_stdout = sys.stdout
|
||||
self._original_stderr = sys.stderr
|
||||
sys.stdout = cast(TextIO, self)
|
||||
sys.stderr = cast(TextIO, self)
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
_type: Any,
|
||||
_value: Any,
|
||||
_traceback: Any,
|
||||
) -> None:
|
||||
sys.stdout = self._original_stdout
|
||||
sys.stderr = self._original_stderr
|
||||
|
||||
|
||||
def write_pkl_gz(
|
||||
objs: List[Any],
|
||||
dirname: str,
|
||||
prefix: str = "",
|
||||
n_jobs: int = 1,
|
||||
progress: bool = False,
|
||||
) -> List[str]:
|
||||
filenames = [f"{dirname}/{prefix}{i:05d}.pkl.gz" for i in range(len(objs))]
|
||||
|
||||
def _process(i: int) -> None:
|
||||
filename = filenames[i]
|
||||
obj = objs[i]
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
with GzipFile(filename, "wb") as file:
|
||||
pickle.dump(obj, cast(IO[bytes], file))
|
||||
|
||||
if n_jobs > 1:
|
||||
p_umap(
|
||||
_process,
|
||||
range(len(objs)),
|
||||
smoothing=0,
|
||||
num_cpus=n_jobs,
|
||||
maxtasksperchild=None,
|
||||
disable=not progress,
|
||||
)
|
||||
else:
|
||||
for i in range(len(objs)):
|
||||
_process(i)
|
||||
return filenames
|
||||
|
||||
|
||||
def gzip(filename: str) -> None:
|
||||
with open(filename, "rb") as input_file:
|
||||
with GzipFile(f"{filename}.gz", "wb") as output_file:
|
||||
shutil.copyfileobj(input_file, output_file)
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
def read_pkl_gz(filename: str) -> Any:
|
||||
with GzipFile(filename, "rb") as file:
|
||||
return pickle.load(cast(IO[bytes], file))
|
||||
|
||||
|
||||
def _to_h5_filename(data_filename: str) -> str:
|
||||
output = f"{data_filename}.h5"
|
||||
output = output.replace(".pkl.gz.h5", ".h5")
|
||||
output = output.replace(".pkl.h5", ".h5")
|
||||
output = output.replace(".jld2.h5", ".h5")
|
||||
return output
|
||||
Reference in New Issue
Block a user