From 331ee5914dd13ab131775795573288a89bbeff82 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Wed, 7 Apr 2021 20:58:44 -0500 Subject: [PATCH] Add types to solvers --- miplearn/solvers/__init__.py | 17 +++++++++++------ miplearn/solvers/gurobi.py | 19 +++++++++++++------ miplearn/solvers/learning.py | 8 ++++++-- miplearn/solvers/pyomo/base.py | 6 +++--- miplearn/solvers/pyomo/cplex.py | 4 ++-- 5 files changed, 35 insertions(+), 19 deletions(-) diff --git a/miplearn/solvers/__init__.py b/miplearn/solvers/__init__.py index e112062..eb535da 100644 --- a/miplearn/solvers/__init__.py +++ b/miplearn/solvers/__init__.py @@ -4,13 +4,13 @@ import logging import sys -from typing import Any, List +from typing import Any, List, TextIO, cast logger = logging.getLogger(__name__) class _RedirectOutput: - def __init__(self, streams: List[Any]): + def __init__(self, streams: List[Any]) -> None: self.streams = streams def write(self, data: Any) -> None: @@ -21,13 +21,18 @@ class _RedirectOutput: for stream in self.streams: stream.flush() - def __enter__(self): + def __enter__(self) -> Any: self._original_stdout = sys.stdout self._original_stderr = sys.stderr - sys.stdout = self - sys.stderr = self + sys.stdout = cast(TextIO, self) + sys.stderr = cast(TextIO, self) return self - def __exit__(self, _type, _value, _traceback): + def __exit__( + self, + _type: Any, + _value: Any, + _traceback: Any, + ) -> None: sys.stdout = self._original_stdout sys.stderr = self._original_stderr diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index bd96c80..06266b4 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -24,6 +24,7 @@ from miplearn.types import ( UserCutCallback, Solution, VariableName, + Constraint, ) logger = logging.getLogger(__name__) @@ -158,7 +159,7 @@ class GurobiSolver(InternalSolver): iteration_cb = lambda: False # Create callback wrapper - def cb_wrapper(cb_model, cb_where): + def cb_wrapper(cb_model: Any, cb_where: int) -> None: try: self.cb_where = cb_where if lazy_cb is not None and cb_where in self.lazy_cb_where: @@ -323,7 +324,8 @@ class GurobiSolver(InternalSolver): var.ub = value @overrides - def get_constraint_ids(self): + def get_constraint_ids(self) -> List[str]: + assert self.model is not None self._raise_if_callback() self.model.update() return [c.ConstrName for c in self.model.getConstrs()] @@ -344,15 +346,20 @@ class GurobiSolver(InternalSolver): return lhs @overrides - def extract_constraint(self, cid): + def extract_constraint(self, cid: str) -> Constraint: self._raise_if_callback() + assert self.model is not None constr = self.model.getConstrByName(cid) cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName) self.model.remove(constr) return cobj @overrides - def is_constraint_satisfied(self, cobj, tol=1e-6): + def is_constraint_satisfied( + self, + cobj: Constraint, + tol: float = 1e-6, + ) -> bool: lhs, sense, rhs, name = cobj if self.cb_where is not None: lhs_value = lhs.getConstant() @@ -416,13 +423,13 @@ class GurobiSolver(InternalSolver): value = matches[0] return value - def __getstate__(self): + def __getstate__(self) -> Dict: return { "params": self.params, "lazy_cb_where": self.lazy_cb_where, } - def __setstate__(self, state): + def __setstate__(self, state: Dict) -> None: self.params = state["params"] self.lazy_cb_where = state["lazy_cb_where"] self.instance = None diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 943c83f..6ea6c31 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -4,7 +4,7 @@ import logging import traceback -from typing import Optional, List, Any, cast, Callable, Dict +from typing import Optional, List, Any, cast, Callable, Dict, Tuple from p_tqdm import p_map @@ -37,10 +37,14 @@ class _GlobalVariables: _GLOBAL = [_GlobalVariables()] -def _parallel_solve(idx): +def _parallel_solve( + idx: int, +) -> Tuple[Optional[LearningSolveStats], Optional[Instance]]: solver = _GLOBAL[0].solver instances = _GLOBAL[0].instances discard_outputs = _GLOBAL[0].discard_outputs + assert solver is not None + assert instances is not None try: stats = solver.solve( instances[idx], diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index 589b264..311dfa5 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -228,7 +228,7 @@ class BasePyomoSolver(InternalSolver): self._pyomo_solver.update_var(var) @overrides - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any) -> Any: self._pyomo_solver.add_constraint(constraint) self._update_constrs() @@ -261,7 +261,7 @@ class BasePyomoSolver(InternalSolver): return int(value) @overrides - def get_constraint_ids(self): + def get_constraint_ids(self) -> List[str]: return list(self._cname_to_constr.keys()) def _get_warm_start_regexp(self) -> Optional[str]: @@ -332,7 +332,7 @@ class BasePyomoSolver(InternalSolver): return self._termination_condition == TerminationCondition.infeasible @overrides - def get_dual(self, cid): + def get_dual(self, cid: str) -> float: raise NotImplementedError() @overrides diff --git a/miplearn/solvers/pyomo/cplex.py b/miplearn/solvers/pyomo/cplex.py index e80474e..2cf8ddd 100644 --- a/miplearn/solvers/pyomo/cplex.py +++ b/miplearn/solvers/pyomo/cplex.py @@ -37,11 +37,11 @@ class CplexPyomoSolver(BasePyomoSolver): ) @overrides - def _get_warm_start_regexp(self): + def _get_warm_start_regexp(self) -> str: return "MIP start .* with objective ([0-9.e+-]*)\\." @overrides - def _get_node_count_regexp(self): + def _get_node_count_regexp(self) -> str: return "^[ *] *([0-9]+)" @overrides