|
|
|
@ -3,7 +3,7 @@
|
|
|
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
|
from os.path import exists
|
|
|
|
|
from tempfile import NamedTemporaryFile
|
|
|
|
|
from typing import List, Any, Union, Dict, Callable, Optional
|
|
|
|
|
from typing import List, Any, Union, Dict, Callable, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
from miplearn.h5 import H5File
|
|
|
|
|
from miplearn.io import _to_h5_filename
|
|
|
|
@ -25,7 +25,7 @@ class LearningSolver:
|
|
|
|
|
self,
|
|
|
|
|
model: Union[str, AbstractModel],
|
|
|
|
|
build_model: Optional[Callable] = None,
|
|
|
|
|
) -> Dict[str, Any]:
|
|
|
|
|
) -> Tuple[AbstractModel, Dict[str, Any]]:
|
|
|
|
|
h5_filename, mode = NamedTemporaryFile().name, "w"
|
|
|
|
|
if isinstance(model, str):
|
|
|
|
|
assert build_model is not None
|
|
|
|
@ -51,4 +51,4 @@ class LearningSolver:
|
|
|
|
|
model.optimize()
|
|
|
|
|
model.extract_after_mip(h5)
|
|
|
|
|
|
|
|
|
|
return stats
|
|
|
|
|
return model, stats
|
|
|
|
|