LearningSolver: Keep original H5 file unmodified

This commit is contained in:
2024-02-02 14:34:20 -06:00
parent 687c271d4d
commit e75850fab8
5 changed files with 25 additions and 8 deletions

View File

@@ -8,6 +8,7 @@ from typing import List, Any, Union, Dict, Callable, Optional
from miplearn.h5 import H5File
from miplearn.io import _to_h5_filename
from miplearn.solvers.abstract import AbstractModel
import shutil
class LearningSolver:
@@ -25,15 +26,20 @@ class LearningSolver:
model: Union[str, AbstractModel],
build_model: Optional[Callable] = None,
) -> Dict[str, Any]:
h5_filename, mode = NamedTemporaryFile().name, "w"
if isinstance(model, str):
h5_filename = _to_h5_filename(model)
assert build_model is not None
old_h5_filename = _to_h5_filename(model)
model = build_model(model)
assert isinstance(model, AbstractModel)
else:
h5_filename = NamedTemporaryFile().name
# If the instance has an associate H5 file, we make a temporary copy of it,
# then work on that copy. We keep the original file unmodified
if exists(old_h5_filename):
shutil.copy(old_h5_filename, h5_filename)
mode = "r+"
stats: Dict[str, Any] = {}
mode = "r+" if exists(h5_filename) else "w"
with H5File(h5_filename, mode) as h5:
model.extract_after_load(h5)
if not self.skip_lp: