diff --git a/.mypy.ini b/.mypy.ini index 976ba02..60bfdf9 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -1,2 +1,5 @@ [mypy] ignore_missing_imports = True +#disallow_untyped_defs = True +disallow_untyped_calls = True +disallow_incomplete_defs = True diff --git a/miplearn/instance.py b/miplearn/instance.py index 850fa1a..6610314 100644 --- a/miplearn/instance.py +++ b/miplearn/instance.py @@ -5,6 +5,7 @@ import gzip import json from abc import ABC, abstractmethod +from typing import Any import numpy as np @@ -20,7 +21,7 @@ class Instance(ABC): """ @abstractmethod - def to_model(self): + def to_model(self) -> Any: """ Returns a concrete Pyomo model corresponding to this instance. """ diff --git a/miplearn/solvers/__init__.py b/miplearn/solvers/__init__.py index ae0d141..37ccc49 100644 --- a/miplearn/solvers/__init__.py +++ b/miplearn/solvers/__init__.py @@ -4,19 +4,20 @@ import logging import sys +from typing import Any, List logger = logging.getLogger(__name__) class RedirectOutput: - def __init__(self, streams): + def __init__(self, streams: List[Any]): self.streams = streams - def write(self, data): + def write(self, data: Any) -> None: for stream in self.streams: stream.write(data) - def flush(self): + def flush(self) -> None: for stream in self.streams: stream.flush() diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 88b494a..a888a71 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -6,7 +6,7 @@ import re import sys from io import StringIO from random import randint -from typing import List, Any, Dict, Union +from typing import List, Any, Dict, Union, Tuple, Optional from . import RedirectOutput from .internal import ( @@ -73,13 +73,14 @@ class GurobiSolver(InternalSolver): self.model.update() self._update_vars() - def _raise_if_callback(self): + def _raise_if_callback(self) -> None: if self.cb_where is not None: raise Exception("method cannot be called from a callback") - def _update_vars(self): + def _update_vars(self) -> None: self._all_vars = {} self._bin_vars = {} + idx: Union[Tuple, List[int], int] for var in self.model.getVars(): m = re.search(r"([^[]*)\[(.*)]", var.varName) if m is None: @@ -100,7 +101,7 @@ class GurobiSolver(InternalSolver): self._bin_vars[name] = {} self._bin_vars[name][idx] = var - def _apply_params(self, streams): + def _apply_params(self, streams: List[Any]) -> None: with RedirectOutput(streams): for (name, value) in self.params.items(): self.model.setParam(name, value) @@ -271,7 +272,7 @@ class GurobiSolver(InternalSolver): else: self.model.addConstr(constraint, name=name) - def _clear_warm_start(self): + def _clear_warm_start(self) -> None: for (varname, vardict) in self._all_vars.items(): for (idx, var) in vardict.items(): var.start = self.GRB.UNDEFINED @@ -338,14 +339,18 @@ class GurobiSolver(InternalSolver): self.model = self.model.relax() self._update_vars() - def _extract_warm_start_value(self, log): + def _extract_warm_start_value(self, log: str) -> Optional[float]: ws = self.__extract(log, "MIP start with objective ([0-9.e+-]*)") - if ws is not None: - ws = float(ws) - return ws + if ws is None: + return None + return float(ws) @staticmethod - def __extract(log, regexp, default=None): + def __extract( + log: str, + regexp: str, + default: Optional[str] = None, + ) -> Optional[str]: value = default for line in log.splitlines(): matches = re.findall(regexp, line) diff --git a/miplearn/solvers/internal.py b/miplearn/solvers/internal.py index f3335d4..d7aa3d3 100644 --- a/miplearn/solvers/internal.py +++ b/miplearn/solvers/internal.py @@ -192,7 +192,7 @@ class InternalSolver(ABC): pass @abstractmethod - def add_constraint(self, cobj: Constraint): + def add_constraint(self, cobj: Constraint) -> None: """ Adds a single constraint to the model. """ @@ -209,7 +209,7 @@ class InternalSolver(ABC): pass @abstractmethod - def is_constraint_satisfied(self, cobj: Constraint): + def is_constraint_satisfied(self, cobj: Constraint) -> bool: """ Returns True if the current solution satisfies the given constraint. """ diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index 7d953cb..a6b5b7a 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -6,7 +6,7 @@ import logging import re import sys from io import StringIO -from typing import Any, List, Dict +from typing import Any, List, Dict, Optional import pyomo from pyomo import environ as pe @@ -169,18 +169,18 @@ class BasePyomoSolver(InternalSolver): variables[str(var)] += [index] return variables - def _clear_warm_start(self): + def _clear_warm_start(self) -> None: for var in self._all_vars: if not var.fixed: var.value = None self._is_warm_start_available = False - def _update_obj(self): + def _update_obj(self) -> None: self._obj_sense = "max" if self._pyomo_solver._objective.sense == pyomo.core.kernel.objective.minimize: self._obj_sense = "min" - def _update_vars(self): + def _update_vars(self) -> None: self._all_vars = [] self._bin_vars = [] self._varname_to_var = {} @@ -191,7 +191,7 @@ class BasePyomoSolver(InternalSolver): if var[idx].domain == pyomo.core.base.set_types.Binary: self._bin_vars += [var[idx]] - def _update_constrs(self): + def _update_constrs(self) -> None: self._cname_to_constr = {} for constr in self.model.component_objects(Constraint): self._cname_to_constr[constr.name] = constr @@ -220,7 +220,11 @@ class BasePyomoSolver(InternalSolver): self._update_constrs() @staticmethod - def __extract(log, regexp, default=None): + def __extract( + log: str, + regexp: Optional[str], + default: Optional[str] = None, + ) -> Optional[str]: if regexp is None: return default value = default @@ -231,22 +235,25 @@ class BasePyomoSolver(InternalSolver): value = matches[0] return value - def _extract_warm_start_value(self, log): + def _extract_warm_start_value(self, log: str) -> Optional[float]: value = self.__extract(log, self._get_warm_start_regexp()) - if value is not None: - value = float(value) - return value + if value is None: + return None + return float(value) - def _extract_node_count(self, log): - return self.__extract(log, self._get_node_count_regexp()) + def _extract_node_count(self, log: str) -> Optional[int]: + value = self.__extract(log, self._get_node_count_regexp()) + if value is None: + return None + return int(value) def get_constraint_ids(self): return list(self._cname_to_constr.keys()) - def _get_warm_start_regexp(self): + def _get_warm_start_regexp(self) -> Optional[str]: return None - def _get_node_count_regexp(self): + def _get_node_count_regexp(self) -> Optional[str]: return None def extract_constraint(self, cid):