diff --git a/miplearn/.io.py.swp b/miplearn/.io.py.swp new file mode 100644 index 0000000..898ac0d Binary files /dev/null and b/miplearn/.io.py.swp differ diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 38d7b08..21388e0 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. from os.path import exists from tempfile import NamedTemporaryFile -from typing import List, Any, Union, Dict, Callable, Optional +from typing import List, Any, Union, Dict, Callable, Optional, Tuple from miplearn.h5 import H5File from miplearn.io import _to_h5_filename @@ -25,7 +25,7 @@ class LearningSolver: self, model: Union[str, AbstractModel], build_model: Optional[Callable] = None, - ) -> Dict[str, Any]: + ) -> Tuple[AbstractModel, Dict[str, Any]]: h5_filename, mode = NamedTemporaryFile().name, "w" if isinstance(model, str): assert build_model is not None @@ -51,4 +51,4 @@ class LearningSolver: model.optimize() model.extract_after_mip(h5) - return stats + return model, stats diff --git a/tests/components/cuts/test_mem.py b/tests/components/cuts/test_mem.py index 4659f10..207e385 100644 --- a/tests/components/cuts/test_mem.py +++ b/tests/components/cuts/test_mem.py @@ -71,5 +71,5 @@ def test_usage_stab( comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor) solver = LearningSolver(components=[comp]) solver.fit(data_filenames) - stats = solver.optimize(data_filenames[0], build_model) # type: ignore + model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore assert stats["Cuts: AOT"] > 0 diff --git a/tests/components/lazy/test_mem.py b/tests/components/lazy/test_mem.py index b3e484c..e445c9f 100644 --- a/tests/components/lazy/test_mem.py +++ b/tests/components/lazy/test_mem.py @@ -65,5 +65,5 @@ def test_usage_tsp( comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor) solver = LearningSolver(components=[comp]) solver.fit(data_filenames) - stats = solver.optimize(data_filenames[0], build_model) # type: ignore + model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore assert stats["Lazy Constraints: AOT"] > 0