LearningSolver: return model

dev
Alinson S. Xavier 1 year ago
parent 7f273ebb70
commit f085ab538b
Signed by: isoron
GPG Key ID: 0DA8E4B9E1109DCA

Binary file not shown.

@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import List, Any, Union, Dict, Callable, Optional
from typing import List, Any, Union, Dict, Callable, Optional, Tuple
from miplearn.h5 import H5File
from miplearn.io import _to_h5_filename
@ -25,7 +25,7 @@ class LearningSolver:
self,
model: Union[str, AbstractModel],
build_model: Optional[Callable] = None,
) -> Dict[str, Any]:
) -> Tuple[AbstractModel, Dict[str, Any]]:
h5_filename, mode = NamedTemporaryFile().name, "w"
if isinstance(model, str):
assert build_model is not None
@ -51,4 +51,4 @@ class LearningSolver:
model.optimize()
model.extract_after_mip(h5)
return stats
return model, stats

@ -71,5 +71,5 @@ def test_usage_stab(
comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor)
solver = LearningSolver(components=[comp])
solver.fit(data_filenames)
stats = solver.optimize(data_filenames[0], build_model) # type: ignore
model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
assert stats["Cuts: AOT"] > 0

@ -65,5 +65,5 @@ def test_usage_tsp(
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
solver = LearningSolver(components=[comp])
solver.fit(data_filenames)
stats = solver.optimize(data_filenames[0], build_model) # type: ignore
model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
assert stats["Lazy Constraints: AOT"] > 0

Loading…
Cancel
Save