mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Add types to remaining InternalSolver methods
This commit is contained in:
@@ -17,6 +17,7 @@ from miplearn.solvers.internal import (
|
||||
LazyCallback,
|
||||
MIPSolveStats,
|
||||
)
|
||||
from miplearn.types import VarIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -230,7 +231,7 @@ class GurobiSolver(InternalSolver):
|
||||
else:
|
||||
return "max"
|
||||
|
||||
def get_value(self, var_name, index):
|
||||
def get_value(self, var_name: str, index: VarIndex) -> Optional[float]:
|
||||
var = self._all_vars[var_name][index]
|
||||
return self._get_value(var)
|
||||
|
||||
@@ -244,26 +245,29 @@ class GurobiSolver(InternalSolver):
|
||||
else:
|
||||
return c.pi
|
||||
|
||||
def _get_value(self, var):
|
||||
def _get_value(self, var: Any) -> Optional[float]:
|
||||
if self.cb_where == self.GRB.Callback.MIPSOL:
|
||||
return self.model.cbGetSolution(var)
|
||||
elif self.cb_where == self.GRB.Callback.MIPNODE:
|
||||
return self.model.cbGetNodeRel(var)
|
||||
elif self.cb_where is None:
|
||||
return var.x
|
||||
if self.is_infeasible():
|
||||
return None
|
||||
else:
|
||||
return var.x
|
||||
else:
|
||||
raise Exception(
|
||||
"get_value cannot be called from cb_where=%s" % self.cb_where
|
||||
)
|
||||
|
||||
def get_variables(self):
|
||||
def get_empty_solution(self) -> Dict:
|
||||
self._raise_if_callback()
|
||||
variables = {}
|
||||
solution: Dict = {}
|
||||
for (varname, vardict) in self._all_vars.items():
|
||||
variables[varname] = []
|
||||
solution[varname] = {}
|
||||
for (idx, var) in vardict.items():
|
||||
variables[varname] += [idx]
|
||||
return variables
|
||||
solution[varname][idx] = None
|
||||
return solution
|
||||
|
||||
def add_constraint(self, constraint, name=""):
|
||||
if type(constraint) is tuple:
|
||||
@@ -325,7 +329,7 @@ class GurobiSolver(InternalSolver):
|
||||
else:
|
||||
raise Exception("Unknown sense: %s" % sense)
|
||||
|
||||
def get_inequality_slacks(self):
|
||||
def get_inequality_slacks(self) -> Dict[str, float]:
|
||||
ineqs = [c for c in self.model.getConstrs() if c.sense != "="]
|
||||
return {c.ConstrName: c.Slack for c in ineqs}
|
||||
|
||||
@@ -341,7 +345,7 @@ class GurobiSolver(InternalSolver):
|
||||
c = self.model.getConstrByName(cid)
|
||||
c.RHS = rhs
|
||||
|
||||
def relax(self):
|
||||
def relax(self) -> None:
|
||||
self.model = self.model.relax()
|
||||
self._update_vars()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user