Add types to solvers

master
Alinson S. Xavier 5 years ago
parent 38212fb858
commit 331ee5914d
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -4,13 +4,13 @@
import logging
import sys
from typing import Any, List
from typing import Any, List, TextIO, cast
logger = logging.getLogger(__name__)
class _RedirectOutput:
def __init__(self, streams: List[Any]):
def __init__(self, streams: List[Any]) -> None:
self.streams = streams
def write(self, data: Any) -> None:
@ -21,13 +21,18 @@ class _RedirectOutput:
for stream in self.streams:
stream.flush()
def __enter__(self):
def __enter__(self) -> Any:
self._original_stdout = sys.stdout
self._original_stderr = sys.stderr
sys.stdout = self
sys.stderr = self
sys.stdout = cast(TextIO, self)
sys.stderr = cast(TextIO, self)
return self
def __exit__(self, _type, _value, _traceback):
def __exit__(
self,
_type: Any,
_value: Any,
_traceback: Any,
) -> None:
sys.stdout = self._original_stdout
sys.stderr = self._original_stderr

@ -24,6 +24,7 @@ from miplearn.types import (
UserCutCallback,
Solution,
VariableName,
Constraint,
)
logger = logging.getLogger(__name__)
@ -158,7 +159,7 @@ class GurobiSolver(InternalSolver):
iteration_cb = lambda: False
# Create callback wrapper
def cb_wrapper(cb_model, cb_where):
def cb_wrapper(cb_model: Any, cb_where: int) -> None:
try:
self.cb_where = cb_where
if lazy_cb is not None and cb_where in self.lazy_cb_where:
@ -323,7 +324,8 @@ class GurobiSolver(InternalSolver):
var.ub = value
@overrides
def get_constraint_ids(self):
def get_constraint_ids(self) -> List[str]:
assert self.model is not None
self._raise_if_callback()
self.model.update()
return [c.ConstrName for c in self.model.getConstrs()]
@ -344,15 +346,20 @@ class GurobiSolver(InternalSolver):
return lhs
@overrides
def extract_constraint(self, cid):
def extract_constraint(self, cid: str) -> Constraint:
self._raise_if_callback()
assert self.model is not None
constr = self.model.getConstrByName(cid)
cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName)
self.model.remove(constr)
return cobj
@overrides
def is_constraint_satisfied(self, cobj, tol=1e-6):
def is_constraint_satisfied(
self,
cobj: Constraint,
tol: float = 1e-6,
) -> bool:
lhs, sense, rhs, name = cobj
if self.cb_where is not None:
lhs_value = lhs.getConstant()
@ -416,13 +423,13 @@ class GurobiSolver(InternalSolver):
value = matches[0]
return value
def __getstate__(self):
def __getstate__(self) -> Dict:
return {
"params": self.params,
"lazy_cb_where": self.lazy_cb_where,
}
def __setstate__(self, state):
def __setstate__(self, state: Dict) -> None:
self.params = state["params"]
self.lazy_cb_where = state["lazy_cb_where"]
self.instance = None

@ -4,7 +4,7 @@
import logging
import traceback
from typing import Optional, List, Any, cast, Callable, Dict
from typing import Optional, List, Any, cast, Callable, Dict, Tuple
from p_tqdm import p_map
@ -37,10 +37,14 @@ class _GlobalVariables:
_GLOBAL = [_GlobalVariables()]
def _parallel_solve(idx):
def _parallel_solve(
idx: int,
) -> Tuple[Optional[LearningSolveStats], Optional[Instance]]:
solver = _GLOBAL[0].solver
instances = _GLOBAL[0].instances
discard_outputs = _GLOBAL[0].discard_outputs
assert solver is not None
assert instances is not None
try:
stats = solver.solve(
instances[idx],

@ -228,7 +228,7 @@ class BasePyomoSolver(InternalSolver):
self._pyomo_solver.update_var(var)
@overrides
def add_constraint(self, constraint):
def add_constraint(self, constraint: Any) -> Any:
self._pyomo_solver.add_constraint(constraint)
self._update_constrs()
@ -261,7 +261,7 @@ class BasePyomoSolver(InternalSolver):
return int(value)
@overrides
def get_constraint_ids(self):
def get_constraint_ids(self) -> List[str]:
return list(self._cname_to_constr.keys())
def _get_warm_start_regexp(self) -> Optional[str]:
@ -332,7 +332,7 @@ class BasePyomoSolver(InternalSolver):
return self._termination_condition == TerminationCondition.infeasible
@overrides
def get_dual(self, cid):
def get_dual(self, cid: str) -> float:
raise NotImplementedError()
@overrides

@ -37,11 +37,11 @@ class CplexPyomoSolver(BasePyomoSolver):
)
@overrides
def _get_warm_start_regexp(self):
def _get_warm_start_regexp(self) -> str:
return "MIP start .* with objective ([0-9.e+-]*)\\."
@overrides
def _get_node_count_regexp(self):
def _get_node_count_regexp(self) -> str:
return "^[ *] *([0-9]+)"
@overrides

Loading…
Cancel
Save