mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Add types to internal solvers
This commit is contained in:
@@ -12,6 +12,7 @@ import pyomo
|
||||
from pyomo import environ as pe
|
||||
from pyomo.core import Var, Constraint
|
||||
from pyomo.opt import TerminationCondition
|
||||
from pyomo.opt.base.solvers import SolverFactory
|
||||
|
||||
from miplearn.instance import Instance
|
||||
from miplearn.solvers import RedirectOutput
|
||||
@@ -22,7 +23,7 @@ from miplearn.solvers.internal import (
|
||||
LazyCallback,
|
||||
MIPSolveStats,
|
||||
)
|
||||
from miplearn.types import VarIndex
|
||||
from miplearn.types import VarIndex, SolverParams, Solution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,19 +35,20 @@ class BasePyomoSolver(InternalSolver):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
solver_factory,
|
||||
params,
|
||||
):
|
||||
self.instance = None
|
||||
self.model = None
|
||||
self._all_vars = None
|
||||
self._bin_vars = None
|
||||
self._is_warm_start_available = False
|
||||
self._pyomo_solver = solver_factory
|
||||
self._obj_sense = None
|
||||
self._varname_to_var = {}
|
||||
self._cname_to_constr = {}
|
||||
self._termination_condition = None
|
||||
solver_factory: SolverFactory,
|
||||
params: SolverParams,
|
||||
) -> None:
|
||||
self.instance: Optional[Instance] = None
|
||||
self.model: Optional[pe.ConcreteModel] = None
|
||||
self._all_vars: List[pe.Var] = []
|
||||
self._bin_vars: List[pe.Var] = []
|
||||
self._is_warm_start_available: bool = False
|
||||
self._pyomo_solver: SolverFactory = solver_factory
|
||||
self._obj_sense: str = "min"
|
||||
self._varname_to_var: Dict[str, pe.Var] = {}
|
||||
self._cname_to_constr: Dict[str, pe.Constraint] = {}
|
||||
self._termination_condition: str = ""
|
||||
|
||||
for (key, value) in params.items():
|
||||
self._pyomo_solver.options[key] = value
|
||||
|
||||
@@ -88,8 +90,6 @@ class BasePyomoSolver(InternalSolver):
|
||||
streams += [sys.stdout]
|
||||
if iteration_cb is None:
|
||||
iteration_cb = lambda: False
|
||||
self.instance.found_violated_lazy_constraints = []
|
||||
self.instance.found_violated_user_cuts = []
|
||||
while True:
|
||||
logger.debug("Solving MIP...")
|
||||
with RedirectOutput(streams):
|
||||
@@ -121,10 +121,11 @@ class BasePyomoSolver(InternalSolver):
|
||||
}
|
||||
return stats
|
||||
|
||||
def get_solution(self) -> Optional[Dict]:
|
||||
def get_solution(self) -> Optional[Solution]:
|
||||
assert self.model is not None
|
||||
if self.is_infeasible():
|
||||
return None
|
||||
solution: Dict = {}
|
||||
solution: Solution = {}
|
||||
for var in self.model.component_objects(Var):
|
||||
solution[str(var)] = {}
|
||||
for index in var:
|
||||
@@ -133,7 +134,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
solution[str(var)][index] = var[index].value
|
||||
return solution
|
||||
|
||||
def set_warm_start(self, solution: Dict) -> None:
|
||||
def set_warm_start(self, solution: Solution) -> None:
|
||||
self._clear_warm_start()
|
||||
count_total, count_fixed = 0, 0
|
||||
for var_name in solution:
|
||||
@@ -172,8 +173,9 @@ class BasePyomoSolver(InternalSolver):
|
||||
var = self._varname_to_var[var_name]
|
||||
return var[index].value
|
||||
|
||||
def get_empty_solution(self) -> Dict:
|
||||
solution: Dict = {}
|
||||
def get_empty_solution(self) -> Solution:
|
||||
assert self.model is not None
|
||||
solution: Solution = {}
|
||||
for var in self.model.component_objects(Var):
|
||||
svar = str(var)
|
||||
solution[svar] = {}
|
||||
@@ -195,6 +197,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
self._obj_sense = "min"
|
||||
|
||||
def _update_vars(self) -> None:
|
||||
assert self.model is not None
|
||||
self._all_vars = []
|
||||
self._bin_vars = []
|
||||
self._varname_to_var = {}
|
||||
@@ -206,6 +209,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
self._bin_vars += [var[idx]]
|
||||
|
||||
def _update_constrs(self) -> None:
|
||||
assert self.model is not None
|
||||
self._cname_to_constr = {}
|
||||
for constr in self.model.component_objects(Constraint):
|
||||
self._cname_to_constr[constr.name] = constr
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
from typing import Optional
|
||||
|
||||
from pyomo import environ as pe
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.solvers.pyomo.base import BasePyomoSolver
|
||||
from miplearn.types import SolverParams
|
||||
|
||||
|
||||
class CplexPyomoSolver(BasePyomoSolver):
|
||||
@@ -19,13 +21,19 @@ class CplexPyomoSolver(BasePyomoSolver):
|
||||
{"mip_display": 5} to increase the log verbosity.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
def __init__(
|
||||
self,
|
||||
params: Optional[SolverParams] = None,
|
||||
) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "randomseed" not in params.keys():
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
if "mip_display" not in params.keys():
|
||||
params["mip_display"] = 4
|
||||
super().__init__(
|
||||
solver_factory=pe.SolverFactory("cplex_persistent"),
|
||||
params={
|
||||
"randomseed": randint(low=0, high=1000).rvs(),
|
||||
"mip_display": 4,
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
|
||||
def _get_warm_start_regexp(self):
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pyomo import environ as pe
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.solvers.pyomo.base import BasePyomoSolver
|
||||
from miplearn.types import SolverParams, BranchPriorities
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,28 +25,35 @@ class GurobiPyomoSolver(BasePyomoSolver):
|
||||
{"Threads": 4} to set the number of threads.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
def __init__(
|
||||
self,
|
||||
params: SolverParams = None,
|
||||
) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "seed" not in params.keys():
|
||||
params["seed"] = randint(low=0, high=1000).rvs()
|
||||
super().__init__(
|
||||
solver_factory=pe.SolverFactory("gurobi_persistent"),
|
||||
params={
|
||||
"Seed": randint(low=0, high=1000).rvs(),
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
|
||||
def _extract_node_count(self, log):
|
||||
def _extract_node_count(self, log: str) -> int:
|
||||
return max(1, int(self._pyomo_solver._solver_model.getAttr("NodeCount")))
|
||||
|
||||
def _get_warm_start_regexp(self):
|
||||
def _get_warm_start_regexp(self) -> str:
|
||||
return "MIP start with objective ([0-9.e+-]*)"
|
||||
|
||||
def _get_node_count_regexp(self):
|
||||
def _get_node_count_regexp(self) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def set_branching_priorities(self, priorities):
|
||||
def set_branching_priorities(self, priorities: BranchPriorities) -> None:
|
||||
from gurobipy import GRB
|
||||
|
||||
for varname in priorities.keys():
|
||||
var = self._varname_to_var[varname]
|
||||
for (index, priority) in priorities[varname].items():
|
||||
if priority is None:
|
||||
continue
|
||||
gvar = self._pyomo_solver._pyomo_var_to_solver_var_map[var[index]]
|
||||
gvar.setAttr(GRB.Attr.BranchPriority, int(round(priority)))
|
||||
|
||||
@@ -8,6 +8,7 @@ from pyomo import environ as pe
|
||||
from scipy.stats import randint
|
||||
|
||||
from miplearn.solvers.pyomo.base import BasePyomoSolver
|
||||
from miplearn.types import SolverParams
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,10 +24,12 @@ class XpressPyomoSolver(BasePyomoSolver):
|
||||
{"Threads": 4} to set the number of threads.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
def __init__(self, params: SolverParams = None) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "randomseed" not in params.keys():
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
super().__init__(
|
||||
solver_factory=pe.SolverFactory("xpress_persistent"),
|
||||
params={
|
||||
"randomseed": randint(low=0, high=1000).rvs(),
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user