LearningSolver: return model

This commit is contained in:
2024-05-31 11:53:56 -05:00
parent 7f273ebb70
commit f085ab538b
4 changed files with 5 additions and 5 deletions

BIN
miplearn/.io.py.swp Normal file

Binary file not shown.

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from os.path import exists from os.path import exists
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from typing import List, Any, Union, Dict, Callable, Optional from typing import List, Any, Union, Dict, Callable, Optional, Tuple
from miplearn.h5 import H5File from miplearn.h5 import H5File
from miplearn.io import _to_h5_filename from miplearn.io import _to_h5_filename
@@ -25,7 +25,7 @@ class LearningSolver:
self, self,
model: Union[str, AbstractModel], model: Union[str, AbstractModel],
build_model: Optional[Callable] = None, build_model: Optional[Callable] = None,
) -> Dict[str, Any]: ) -> Tuple[AbstractModel, Dict[str, Any]]:
h5_filename, mode = NamedTemporaryFile().name, "w" h5_filename, mode = NamedTemporaryFile().name, "w"
if isinstance(model, str): if isinstance(model, str):
assert build_model is not None assert build_model is not None
@@ -51,4 +51,4 @@ class LearningSolver:
model.optimize() model.optimize()
model.extract_after_mip(h5) model.extract_after_mip(h5)
return stats return model, stats

View File

@@ -71,5 +71,5 @@ def test_usage_stab(
comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor) comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor)
solver = LearningSolver(components=[comp]) solver = LearningSolver(components=[comp])
solver.fit(data_filenames) solver.fit(data_filenames)
stats = solver.optimize(data_filenames[0], build_model) # type: ignore model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
assert stats["Cuts: AOT"] > 0 assert stats["Cuts: AOT"] > 0

View File

@@ -65,5 +65,5 @@ def test_usage_tsp(
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor) comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
solver = LearningSolver(components=[comp]) solver = LearningSolver(components=[comp])
solver.fit(data_filenames) solver.fit(data_filenames)
stats = solver.optimize(data_filenames[0], build_model) # type: ignore model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
assert stats["Lazy Constraints: AOT"] > 0 assert stats["Lazy Constraints: AOT"] > 0