From f085ab538baf58abe84f903313072109f7f33f1c Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Fri, 31 May 2024 11:53:56 -0500 Subject: [PATCH] LearningSolver: return model --- miplearn/.io.py.swp | Bin 0 -> 12288 bytes miplearn/solvers/learning.py | 6 +++--- tests/components/cuts/test_mem.py | 2 +- tests/components/lazy/test_mem.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) create mode 100644 miplearn/.io.py.swp diff --git a/miplearn/.io.py.swp b/miplearn/.io.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..898ac0df50795fa2d643cf5140a7fb74001933cc GIT binary patch literal 12288 zcmeI2O^6&t6vrzGYGRBA1rvy1IY#M`+1|;9HDO?*ar43ACcETA4C~VAnXZ|sOn=c; zvpch4%^?Sk7q5Z|8t|m%EcgKm9`xunCj)vCyd;7L1^-pmJ+r$e3v&!q!*6@K>eYL% z{`IPQV5eSvYvvR^>`o9|y9s&r?+eZ6#$VYsd1nvtd;StIa^CTLGDFAy+Vkt@$)A=Jzcd?)js= zSp}>DRspMkRlq7>6|f3e1*`&A0jq#jz$)+%D&Pl%+}TFR>LW-V|Np<3&*9cngfOrd z>;eBiNyuIBJGc#Qf=|I!&;=3*@D6am4zL~kx{Hu+!3}U7TmcKzQ+hd`~xEO-;x%PL?McsLZGxH2ZYSxl3zOuLe(7W0u# z(oPRUJG_ecXeq1&D;RcJ}w%0!KSu=^+#BUp#&2*#I05|3gY1`{KgH^kBi z)~%HX8#har17G@HlZPz!BetqcNx+)alZn@!9CSGA%xYA~tW3v_sCx~?r-W8lzE~-5 zbVBBai61x(Ur1+WZlT_nOe|KY3RN~sv&?0io?PMS3A8yy4p_acUMf>xXxo}jCu4r# zrJc}gtqcr@Ssab~MHY!b0Y%Bu(4UEA@}pF=yOM`)BT4%zU{b&2@EE12XIjatML!J3 zZOF59xkz=@ZLO$kTh8+}wWdech>EC3CMk;tNw1mRFs_+n3rY;^hLby;#H^TIu5S@v z=pZVPy1X4^wX@8ReA6H1p#YbFj(|rYy>3lOGEOeCyHpGaRN`0xtY|$fn{% zQh073S4_p;rKG-Lyoi!Swt3u|s;(ck zG1>LK#XN_tsn?DC)FCP{-lQfLJwgvaH`P{Ld0htGC{;34=qAOf#%x&le3PtMaJ3La z{ZM<6-(dmIgoDqiFWW92)W#aDYq>&Mj|=G~ovCwKmkl*XLh2C2>MZa@QlwlgHM)wA zgMPhF2a8_AGE2QF^NLX}24gUOLG5;v&g%H(G-FNPTfNL{)yZJh#okiIOdnV?54N5k%3 ztmsolk&;PMcP#l~mu+m6{d8lS$}FF+I$=E&g7iFRSfly&R(3>rR@b_@>3TiK zdU9s2q9-~mpLG-4%L4=1?wtsXNMs z`H|ZUyQ1w3Ggl>Ny3|3XgPKQSYE1bzV+Ow0RCo~}cdkoBzpX|KV8*Ec6{_BKdO6or z0U5+mqU8B0+4P*uL2QKJv0hcIlGpfbguW2eJEKkv){t}hJCQo-d6XuZ%->a{ znI#d$X_EP2$U>TTnlq1>UKAUS(V|NghpJFIjS=b$ujugRxT{c8KZ78BQG7JVWAT&- z<6FI!3eUyM8o!UF^bEt#(FF@=H^$bZvdw6e1iZ=7^v2m^G~^8y3+B?ZjM3@2`3p0r zPr6Z{kC-UQ7!8>8c_>`vY#J|QY}vM?&UBehR41xq_%l(V=cn7e;kOccG;1X>-Y;}^ zcG{&!!;orwfiE3`WlJp3eu`hr*XZ${M0k$r(-Rqv;N>Li=-B!f^^j}qc-;2m2C_ZH Xdn_26K?S!kk-1dzh_CpPCo%aCN57cb literal 0 HcmV?d00001 diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 38d7b08..21388e0 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. from os.path import exists from tempfile import NamedTemporaryFile -from typing import List, Any, Union, Dict, Callable, Optional +from typing import List, Any, Union, Dict, Callable, Optional, Tuple from miplearn.h5 import H5File from miplearn.io import _to_h5_filename @@ -25,7 +25,7 @@ class LearningSolver: self, model: Union[str, AbstractModel], build_model: Optional[Callable] = None, - ) -> Dict[str, Any]: + ) -> Tuple[AbstractModel, Dict[str, Any]]: h5_filename, mode = NamedTemporaryFile().name, "w" if isinstance(model, str): assert build_model is not None @@ -51,4 +51,4 @@ class LearningSolver: model.optimize() model.extract_after_mip(h5) - return stats + return model, stats diff --git a/tests/components/cuts/test_mem.py b/tests/components/cuts/test_mem.py index 4659f10..207e385 100644 --- a/tests/components/cuts/test_mem.py +++ b/tests/components/cuts/test_mem.py @@ -71,5 +71,5 @@ def test_usage_stab( comp = MemorizingCutsComponent(clf=clf, extractor=default_extractor) solver = LearningSolver(components=[comp]) solver.fit(data_filenames) - stats = solver.optimize(data_filenames[0], build_model) # type: ignore + model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore assert stats["Cuts: AOT"] > 0 diff --git a/tests/components/lazy/test_mem.py b/tests/components/lazy/test_mem.py index b3e484c..e445c9f 100644 --- a/tests/components/lazy/test_mem.py +++ b/tests/components/lazy/test_mem.py @@ -65,5 +65,5 @@ def test_usage_tsp( comp = MemorizingLazyComponent(clf=clf, extractor=default_extractor) solver = LearningSolver(components=[comp]) solver.fit(data_filenames) - stats = solver.optimize(data_filenames[0], build_model) # type: ignore + model, stats = solver.optimize(data_filenames[0], build_model) # type: ignore assert stats["Lazy Constraints: AOT"] > 0