mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
LearningSolver: Keep original H5 file unmodified
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import List, Any, Union, Dict, Callable, Optional
|
||||
from miplearn.h5 import H5File
|
||||
from miplearn.io import _to_h5_filename
|
||||
from miplearn.solvers.abstract import AbstractModel
|
||||
import shutil
|
||||
|
||||
|
||||
class LearningSolver:
|
||||
@@ -25,15 +26,20 @@ class LearningSolver:
|
||||
model: Union[str, AbstractModel],
|
||||
build_model: Optional[Callable] = None,
|
||||
) -> Dict[str, Any]:
|
||||
h5_filename, mode = NamedTemporaryFile().name, "w"
|
||||
if isinstance(model, str):
|
||||
h5_filename = _to_h5_filename(model)
|
||||
assert build_model is not None
|
||||
old_h5_filename = _to_h5_filename(model)
|
||||
model = build_model(model)
|
||||
assert isinstance(model, AbstractModel)
|
||||
else:
|
||||
h5_filename = NamedTemporaryFile().name
|
||||
|
||||
# If the instance has an associate H5 file, we make a temporary copy of it,
|
||||
# then work on that copy. We keep the original file unmodified
|
||||
if exists(old_h5_filename):
|
||||
shutil.copy(old_h5_filename, h5_filename)
|
||||
mode = "r+"
|
||||
|
||||
stats: Dict[str, Any] = {}
|
||||
mode = "r+" if exists(h5_filename) else "w"
|
||||
with H5File(h5_filename, mode) as h5:
|
||||
model.extract_after_load(h5)
|
||||
if not self.skip_lp:
|
||||
|
||||
Reference in New Issue
Block a user