Fix mypy errors

This commit is contained in:
2023-10-26 13:39:57 -05:00
parent e555dffc0c
commit 2d07a44f7d
12 changed files with 61 additions and 41 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import List, Any, Union
from typing import List, Any, Union, Dict, Callable, Optional
from miplearn.h5 import H5File
from miplearn.io import _to_h5_filename
@@ -11,23 +11,28 @@ from miplearn.solvers.abstract import AbstractModel
class LearningSolver:
def __init__(self, components: List[Any], skip_lp=False):
def __init__(self, components: List[Any], skip_lp: bool = False) -> None:
self.components = components
self.skip_lp = skip_lp
def fit(self, data_filenames):
def fit(self, data_filenames: List[str]) -> None:
h5_filenames = [_to_h5_filename(f) for f in data_filenames]
for comp in self.components:
comp.fit(h5_filenames)
def optimize(self, model: Union[str, AbstractModel], build_model=None):
def optimize(
self,
model: Union[str, AbstractModel],
build_model: Optional[Callable] = None,
) -> Dict[str, Any]:
if isinstance(model, str):
h5_filename = _to_h5_filename(model)
assert build_model is not None
model = build_model(model)
assert isinstance(model, AbstractModel)
else:
h5_filename = NamedTemporaryFile().name
stats = {}
stats: Dict[str, Any] = {}
mode = "r+" if exists(h5_filename) else "w"
with H5File(h5_filename, mode) as h5:
model.extract_after_load(h5)