mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Add types to solvers
This commit is contained in:
@@ -4,13 +4,13 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, List
|
from typing import Any, List, TextIO, cast
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class _RedirectOutput:
|
class _RedirectOutput:
|
||||||
def __init__(self, streams: List[Any]):
|
def __init__(self, streams: List[Any]) -> None:
|
||||||
self.streams = streams
|
self.streams = streams
|
||||||
|
|
||||||
def write(self, data: Any) -> None:
|
def write(self, data: Any) -> None:
|
||||||
@@ -21,13 +21,18 @@ class _RedirectOutput:
|
|||||||
for stream in self.streams:
|
for stream in self.streams:
|
||||||
stream.flush()
|
stream.flush()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> Any:
|
||||||
self._original_stdout = sys.stdout
|
self._original_stdout = sys.stdout
|
||||||
self._original_stderr = sys.stderr
|
self._original_stderr = sys.stderr
|
||||||
sys.stdout = self
|
sys.stdout = cast(TextIO, self)
|
||||||
sys.stderr = self
|
sys.stderr = cast(TextIO, self)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, _type, _value, _traceback):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
_type: Any,
|
||||||
|
_value: Any,
|
||||||
|
_traceback: Any,
|
||||||
|
) -> None:
|
||||||
sys.stdout = self._original_stdout
|
sys.stdout = self._original_stdout
|
||||||
sys.stderr = self._original_stderr
|
sys.stderr = self._original_stderr
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from miplearn.types import (
|
|||||||
UserCutCallback,
|
UserCutCallback,
|
||||||
Solution,
|
Solution,
|
||||||
VariableName,
|
VariableName,
|
||||||
|
Constraint,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -158,7 +159,7 @@ class GurobiSolver(InternalSolver):
|
|||||||
iteration_cb = lambda: False
|
iteration_cb = lambda: False
|
||||||
|
|
||||||
# Create callback wrapper
|
# Create callback wrapper
|
||||||
def cb_wrapper(cb_model, cb_where):
|
def cb_wrapper(cb_model: Any, cb_where: int) -> None:
|
||||||
try:
|
try:
|
||||||
self.cb_where = cb_where
|
self.cb_where = cb_where
|
||||||
if lazy_cb is not None and cb_where in self.lazy_cb_where:
|
if lazy_cb is not None and cb_where in self.lazy_cb_where:
|
||||||
@@ -323,7 +324,8 @@ class GurobiSolver(InternalSolver):
|
|||||||
var.ub = value
|
var.ub = value
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_constraint_ids(self):
|
def get_constraint_ids(self) -> List[str]:
|
||||||
|
assert self.model is not None
|
||||||
self._raise_if_callback()
|
self._raise_if_callback()
|
||||||
self.model.update()
|
self.model.update()
|
||||||
return [c.ConstrName for c in self.model.getConstrs()]
|
return [c.ConstrName for c in self.model.getConstrs()]
|
||||||
@@ -344,15 +346,20 @@ class GurobiSolver(InternalSolver):
|
|||||||
return lhs
|
return lhs
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def extract_constraint(self, cid):
|
def extract_constraint(self, cid: str) -> Constraint:
|
||||||
self._raise_if_callback()
|
self._raise_if_callback()
|
||||||
|
assert self.model is not None
|
||||||
constr = self.model.getConstrByName(cid)
|
constr = self.model.getConstrByName(cid)
|
||||||
cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName)
|
cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName)
|
||||||
self.model.remove(constr)
|
self.model.remove(constr)
|
||||||
return cobj
|
return cobj
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def is_constraint_satisfied(self, cobj, tol=1e-6):
|
def is_constraint_satisfied(
|
||||||
|
self,
|
||||||
|
cobj: Constraint,
|
||||||
|
tol: float = 1e-6,
|
||||||
|
) -> bool:
|
||||||
lhs, sense, rhs, name = cobj
|
lhs, sense, rhs, name = cobj
|
||||||
if self.cb_where is not None:
|
if self.cb_where is not None:
|
||||||
lhs_value = lhs.getConstant()
|
lhs_value = lhs.getConstant()
|
||||||
@@ -416,13 +423,13 @@ class GurobiSolver(InternalSolver):
|
|||||||
value = matches[0]
|
value = matches[0]
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self) -> Dict:
|
||||||
return {
|
return {
|
||||||
"params": self.params,
|
"params": self.params,
|
||||||
"lazy_cb_where": self.lazy_cb_where,
|
"lazy_cb_where": self.lazy_cb_where,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state: Dict) -> None:
|
||||||
self.params = state["params"]
|
self.params = state["params"]
|
||||||
self.lazy_cb_where = state["lazy_cb_where"]
|
self.lazy_cb_where = state["lazy_cb_where"]
|
||||||
self.instance = None
|
self.instance = None
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, List, Any, cast, Callable, Dict
|
from typing import Optional, List, Any, cast, Callable, Dict, Tuple
|
||||||
|
|
||||||
from p_tqdm import p_map
|
from p_tqdm import p_map
|
||||||
|
|
||||||
@@ -37,10 +37,14 @@ class _GlobalVariables:
|
|||||||
_GLOBAL = [_GlobalVariables()]
|
_GLOBAL = [_GlobalVariables()]
|
||||||
|
|
||||||
|
|
||||||
def _parallel_solve(idx):
|
def _parallel_solve(
|
||||||
|
idx: int,
|
||||||
|
) -> Tuple[Optional[LearningSolveStats], Optional[Instance]]:
|
||||||
solver = _GLOBAL[0].solver
|
solver = _GLOBAL[0].solver
|
||||||
instances = _GLOBAL[0].instances
|
instances = _GLOBAL[0].instances
|
||||||
discard_outputs = _GLOBAL[0].discard_outputs
|
discard_outputs = _GLOBAL[0].discard_outputs
|
||||||
|
assert solver is not None
|
||||||
|
assert instances is not None
|
||||||
try:
|
try:
|
||||||
stats = solver.solve(
|
stats = solver.solve(
|
||||||
instances[idx],
|
instances[idx],
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ class BasePyomoSolver(InternalSolver):
|
|||||||
self._pyomo_solver.update_var(var)
|
self._pyomo_solver.update_var(var)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def add_constraint(self, constraint):
|
def add_constraint(self, constraint: Any) -> Any:
|
||||||
self._pyomo_solver.add_constraint(constraint)
|
self._pyomo_solver.add_constraint(constraint)
|
||||||
self._update_constrs()
|
self._update_constrs()
|
||||||
|
|
||||||
@@ -261,7 +261,7 @@ class BasePyomoSolver(InternalSolver):
|
|||||||
return int(value)
|
return int(value)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_constraint_ids(self):
|
def get_constraint_ids(self) -> List[str]:
|
||||||
return list(self._cname_to_constr.keys())
|
return list(self._cname_to_constr.keys())
|
||||||
|
|
||||||
def _get_warm_start_regexp(self) -> Optional[str]:
|
def _get_warm_start_regexp(self) -> Optional[str]:
|
||||||
@@ -332,7 +332,7 @@ class BasePyomoSolver(InternalSolver):
|
|||||||
return self._termination_condition == TerminationCondition.infeasible
|
return self._termination_condition == TerminationCondition.infeasible
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_dual(self, cid):
|
def get_dual(self, cid: str) -> float:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ class CplexPyomoSolver(BasePyomoSolver):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def _get_warm_start_regexp(self):
|
def _get_warm_start_regexp(self) -> str:
|
||||||
return "MIP start .* with objective ([0-9.e+-]*)\\."
|
return "MIP start .* with objective ([0-9.e+-]*)\\."
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def _get_node_count_regexp(self):
|
def _get_node_count_regexp(self) -> str:
|
||||||
return "^[ *] *([0-9]+)"
|
return "^[ *] *([0-9]+)"
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
|
|||||||
Reference in New Issue
Block a user