mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
LearningSolver: return model
This commit is contained in:
BIN
miplearn/.io.py.swp
Normal file
BIN
miplearn/.io.py.swp
Normal file
Binary file not shown.
@@ -3,7 +3,7 @@
|
|||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import List, Any, Union, Dict, Callable, Optional
|
from typing import List, Any, Union, Dict, Callable, Optional, Tuple
|
||||||
|
|
||||||
from miplearn.h5 import H5File
|
from miplearn.h5 import H5File
|
||||||
from miplearn.io import _to_h5_filename
|
from miplearn.io import _to_h5_filename
|
||||||
@@ -25,7 +25,7 @@ class LearningSolver:
|
|||||||
self,
|
self,
|
||||||
model: Union[str, AbstractModel],
|
model: Union[str, AbstractModel],
|
||||||
build_model: Optional[Callable] = None,
|
build_model: Optional[Callable] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Tuple[AbstractModel, Dict[str, Any]]:
|
||||||
h5_filename, mode = NamedTemporaryFile().name, "w"
|
h5_filename, mode = NamedTemporaryFile().name, "w"
|
||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
assert build_model is not None
|
assert build_model is not None
|
||||||
@@ -51,4 +51,4 @@ class LearningSolver:
|
|||||||
model.optimize()
|
model.optimize()
|
||||||
model.extract_after_mip(h5)
|
model.extract_after_mip(h5)
|
||||||
|
|
||||||
return stats
|
return model, stats
|
||||||
|
|||||||
@@ -71,5 +71,5 @@ def test_usage_stab(
|
|||||||
comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor)
|
comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor)
|
||||||
solver = LearningSolver(components=[comp])
|
solver = LearningSolver(components=[comp])
|
||||||
solver.fit(data_filenames)
|
solver.fit(data_filenames)
|
||||||
stats = solver.optimize(data_filenames[0], build_model) # type: ignore
|
model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
|
||||||
assert stats["Cuts: AOT"] > 0
|
assert stats["Cuts: AOT"] > 0
|
||||||
|
|||||||
@@ -65,5 +65,5 @@ def test_usage_tsp(
|
|||||||
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
|
comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor)
|
||||||
solver = LearningSolver(components=[comp])
|
solver = LearningSolver(components=[comp])
|
||||||
solver.fit(data_filenames)
|
solver.fit(data_filenames)
|
||||||
stats = solver.optimize(data_filenames[0], build_model) # type: ignore
|
model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore
|
||||||
assert stats["Lazy Constraints: AOT"] > 0
|
assert stats["Lazy Constraints: AOT"] > 0
|
||||||
|
|||||||
Reference in New Issue
Block a user