Add types to solvers

This commit is contained in:
2021-04-07 20:58:44 -05:00
parent 38212fb858
commit 331ee5914d
5 changed files with 35 additions and 19 deletions

View File

@@ -24,6 +24,7 @@ from miplearn.types import (
UserCutCallback,
Solution,
VariableName,
Constraint,
)
logger = logging.getLogger(__name__)
@@ -158,7 +159,7 @@ class GurobiSolver(InternalSolver):
iteration_cb = lambda: False
# Create callback wrapper
def cb_wrapper(cb_model, cb_where):
def cb_wrapper(cb_model: Any, cb_where: int) -> None:
try:
self.cb_where = cb_where
if lazy_cb is not None and cb_where in self.lazy_cb_where:
@@ -323,7 +324,8 @@ class GurobiSolver(InternalSolver):
var.ub = value
@overrides
def get_constraint_ids(self):
def get_constraint_ids(self) -> List[str]:
assert self.model is not None
self._raise_if_callback()
self.model.update()
return [c.ConstrName for c in self.model.getConstrs()]
@@ -344,15 +346,20 @@ class GurobiSolver(InternalSolver):
return lhs
@overrides
def extract_constraint(self, cid):
def extract_constraint(self, cid: str) -> Constraint:
self._raise_if_callback()
assert self.model is not None
constr = self.model.getConstrByName(cid)
cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName)
self.model.remove(constr)
return cobj
@overrides
def is_constraint_satisfied(self, cobj, tol=1e-6):
def is_constraint_satisfied(
self,
cobj: Constraint,
tol: float = 1e-6,
) -> bool:
lhs, sense, rhs, name = cobj
if self.cb_where is not None:
lhs_value = lhs.getConstant()
@@ -416,13 +423,13 @@ class GurobiSolver(InternalSolver):
value = matches[0]
return value
def __getstate__(self):
def __getstate__(self) -> Dict:
return {
"params": self.params,
"lazy_cb_where": self.lazy_cb_where,
}
def __setstate__(self, state):
def __setstate__(self, state: Dict) -> None:
self.params = state["params"]
self.lazy_cb_where = state["lazy_cb_where"]
self.instance = None