mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 18:38:51 -06:00
Fix mypy errors
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from os.path import exists
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import List, Any, Union
|
||||
from typing import List, Any, Union, Dict, Callable, Optional
|
||||
|
||||
from miplearn.h5 import H5File
|
||||
from miplearn.io import _to_h5_filename
|
||||
@@ -11,23 +11,28 @@ from miplearn.solvers.abstract import AbstractModel
|
||||
|
||||
|
||||
class LearningSolver:
|
||||
def __init__(self, components: List[Any], skip_lp=False):
|
||||
def __init__(self, components: List[Any], skip_lp: bool = False) -> None:
|
||||
self.components = components
|
||||
self.skip_lp = skip_lp
|
||||
|
||||
def fit(self, data_filenames):
|
||||
def fit(self, data_filenames: List[str]) -> None:
|
||||
h5_filenames = [_to_h5_filename(f) for f in data_filenames]
|
||||
for comp in self.components:
|
||||
comp.fit(h5_filenames)
|
||||
|
||||
def optimize(self, model: Union[str, AbstractModel], build_model=None):
|
||||
def optimize(
|
||||
self,
|
||||
model: Union[str, AbstractModel],
|
||||
build_model: Optional[Callable] = None,
|
||||
) -> Dict[str, Any]:
|
||||
if isinstance(model, str):
|
||||
h5_filename = _to_h5_filename(model)
|
||||
assert build_model is not None
|
||||
model = build_model(model)
|
||||
assert isinstance(model, AbstractModel)
|
||||
else:
|
||||
h5_filename = NamedTemporaryFile().name
|
||||
stats = {}
|
||||
stats: Dict[str, Any] = {}
|
||||
mode = "r+" if exists(h5_filename) else "w"
|
||||
with H5File(h5_filename, mode) as h5:
|
||||
model.extract_after_load(h5)
|
||||
|
||||
Reference in New Issue
Block a user