You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
2.6 KiB
97 lines
2.6 KiB
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
from gzip import GzipFile
|
|
import os
|
|
import pickle
|
|
import sys
|
|
from typing import IO, Any, Callable, List, cast, TextIO
|
|
|
|
from .parallel import p_umap
|
|
import shutil
|
|
|
|
|
|
class _RedirectOutput:
|
|
def __init__(self, streams: List[Any]) -> None:
|
|
self.streams = streams
|
|
|
|
def write(self, data: Any) -> None:
|
|
for stream in self.streams:
|
|
stream.write(data)
|
|
|
|
def flush(self) -> None:
|
|
for stream in self.streams:
|
|
stream.flush()
|
|
|
|
def __enter__(self) -> Any:
|
|
self._original_stdout = sys.stdout
|
|
self._original_stderr = sys.stderr
|
|
sys.stdout = cast(TextIO, self)
|
|
sys.stderr = cast(TextIO, self)
|
|
return self
|
|
|
|
def __exit__(
|
|
self,
|
|
_type: Any,
|
|
_value: Any,
|
|
_traceback: Any,
|
|
) -> None:
|
|
sys.stdout = self._original_stdout
|
|
sys.stderr = self._original_stderr
|
|
|
|
|
|
def write_pkl_gz(
    objs: List[Any],
    dirname: str,
    prefix: str = "",
    n_jobs: int = 1,
    progress: bool = False,
) -> List[str]:
    """
    Pickle each object in ``objs`` into its own gzip-compressed file.

    Files are named ``{dirname}/{prefix}{i:05d}.pkl.gz``, where ``i`` is the
    object's position in ``objs``.

    Parameters
    ----------
    objs
        Objects to serialize; each one is written with ``pickle.dump``.
    dirname
        Output directory. The directory holding the files (``dirname`` plus
        any directory part of ``prefix``) is created if it does not exist.
    prefix
        String prepended to each file's zero-padded index.
    n_jobs
        When greater than 1, files are written through ``p_umap`` with this
        many CPUs; otherwise they are written sequentially in this process.
    progress
        Forwarded to ``p_umap`` as ``disable=not progress``. It has no effect
        on the sequential path (``n_jobs <= 1``).

    Returns
    -------
    List[str]
        The written file paths, in the same order as ``objs``.
    """
    filenames = [f"{dirname}/{prefix}{i:05d}.pkl.gz" for i in range(len(objs))]
    if not filenames:
        # Nothing to write: don't create directories or start workers.
        return filenames

    # Only the numeric index differs between filenames, so they all share one
    # parent directory. Create it once here rather than once per file (and
    # once per worker process on the parallel path).
    os.makedirs(os.path.dirname(filenames[0]), exist_ok=True)

    def _process(i: int) -> None:
        # Serialize objs[i] into its pre-computed destination.
        with GzipFile(filenames[i], "wb") as file:
            pickle.dump(objs[i], cast(IO[bytes], file))

    if n_jobs > 1:
        # Completion order is irrelevant: each task writes its own file.
        p_umap(
            _process,
            range(len(objs)),
            smoothing=0,
            num_cpus=n_jobs,
            maxtasksperchild=None,
            disable=not progress,
        )
    else:
        for i in range(len(objs)):
            _process(i)
    return filenames
|
|
|
|
|
|
def gzip(filename: str) -> None:
    """
    Compress ``filename`` into ``{filename}.gz`` and delete the original.

    An existing ``{filename}.gz`` is overwritten. The original file is removed
    only after the compressed copy has been fully written and closed, so it
    survives if compression raises. A partially written ``.gz`` is not
    cleaned up in that case.
    """
    target = f"{filename}.gz"
    with open(filename, "rb") as src, GzipFile(target, "wb") as dst:
        shutil.copyfileobj(src, dst)
    os.remove(filename)
|
|
|
|
|
|
def read_pkl_gz(filename: str) -> Any:
    """
    Load and return the pickled object stored in a gzip-compressed file.

    This is the reading side of ``write_pkl_gz``.

    Warning: ``pickle.load`` can execute arbitrary code while unpickling.
    Only call this on files from a trusted source.
    """
    with GzipFile(filename, "rb") as compressed:
        obj = pickle.load(cast(IO[bytes], compressed))
    return obj
|
|
|
|
|
|
def _to_h5_filename(data_filename: str) -> str:
|
|
output = f"{data_filename}.h5"
|
|
output = output.replace(".gz.h5", ".h5")
|
|
output = output.replace(".csv.h5", ".h5")
|
|
output = output.replace(".jld2.h5", ".h5")
|
|
output = output.replace(".json.h5", ".h5")
|
|
output = output.replace(".lp.h5", ".h5")
|
|
output = output.replace(".mps.h5", ".h5")
|
|
output = output.replace(".pkl.h5", ".h5")
|
|
return output
|