Disallow untyped calls and incomplete defs

This commit is contained in:
2021-01-20 10:48:03 -06:00
parent 7555f561f8
commit 947189f25f
6 changed files with 47 additions and 30 deletions

View File

@@ -6,7 +6,7 @@ import re
import sys
from io import StringIO
from random import randint
from typing import List, Any, Dict, Union
from typing import List, Any, Dict, Union, Tuple, Optional
from . import RedirectOutput
from .internal import (
@@ -73,13 +73,14 @@ class GurobiSolver(InternalSolver):
self.model.update()
self._update_vars()
def _raise_if_callback(self):
def _raise_if_callback(self) -> None:
if self.cb_where is not None:
raise Exception("method cannot be called from a callback")
def _update_vars(self):
def _update_vars(self) -> None:
self._all_vars = {}
self._bin_vars = {}
idx: Union[Tuple, List[int], int]
for var in self.model.getVars():
m = re.search(r"([^[]*)\[(.*)]", var.varName)
if m is None:
@@ -100,7 +101,7 @@ class GurobiSolver(InternalSolver):
self._bin_vars[name] = {}
self._bin_vars[name][idx] = var
def _apply_params(self, streams):
def _apply_params(self, streams: List[Any]) -> None:
with RedirectOutput(streams):
for (name, value) in self.params.items():
self.model.setParam(name, value)
@@ -271,7 +272,7 @@ class GurobiSolver(InternalSolver):
else:
self.model.addConstr(constraint, name=name)
def _clear_warm_start(self):
def _clear_warm_start(self) -> None:
for (varname, vardict) in self._all_vars.items():
for (idx, var) in vardict.items():
var.start = self.GRB.UNDEFINED
@@ -338,14 +339,18 @@ class GurobiSolver(InternalSolver):
self.model = self.model.relax()
self._update_vars()
def _extract_warm_start_value(self, log):
def _extract_warm_start_value(self, log: str) -> Optional[float]:
ws = self.__extract(log, "MIP start with objective ([0-9.e+-]*)")
if ws is not None:
ws = float(ws)
return ws
if ws is None:
return None
return float(ws)
@staticmethod
def __extract(log, regexp, default=None):
def __extract(
log: str,
regexp: str,
default: Optional[str] = None,
) -> Optional[str]:
value = default
for line in log.splitlines():
matches = re.findall(regexp, line)