From e6eca2ee7f488ec5be4d149e78a39b9a69d9ce8a Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Thu, 15 Apr 2021 04:12:10 -0500 Subject: [PATCH] GurobiSolver: Performance improvements --- miplearn/components/component.py | 2 +- miplearn/solvers/gurobi.py | 143 +++++++++++++++++-------------- 2 files changed, 82 insertions(+), 63 deletions(-) diff --git a/miplearn/components/component.py b/miplearn/components/component.py index 986fb3a..63d5dad 100644 --- a/miplearn/components/component.py +++ b/miplearn/components/component.py @@ -169,7 +169,7 @@ class Component(EnforceOverrides): """ pass - def pre_fit(self, pre: List[Any]): + def pre_fit(self, pre: List[Any]) -> None: pass def user_cut_cb( diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index d60ac1d..d6d9890 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -25,8 +25,6 @@ from miplearn.types import ( SolverParams, UserCutCallback, Solution, - VariableName, - Category, ) logger = logging.getLogger(__name__) @@ -66,13 +64,18 @@ class GurobiSolver(InternalSolver): self.params: SolverParams = params self.cb_where: Optional[int] = None self.lazy_cb_frequency = lazy_cb_frequency - self._bin_vars: List["gurobipy.Var"] = [] - self._varname_to_var: Dict[str, "gurobipy.Var"] = {} - self._original_vtype: Dict["gurobipy.Var", str] = {} self._dirty = True self._has_lp_solution = False self._has_mip_solution = False + self._varname_to_var: Dict[str, "gurobipy.Var"] = {} + self._gp_vars: List["gurobipy.Var"] = [] + self._var_names: List[str] = [] + self._var_types: List[str] = [] + self._var_lbs: List[float] = [] + self._var_ubs: List[float] = [] + self._var_obj_coeffs: List[float] = [] + if self.lazy_cb_frequency == 1: self.lazy_cb_where = [self.gp.GRB.Callback.MIPSOL] else: @@ -84,6 +87,8 @@ class GurobiSolver(InternalSolver): @overrides def add_constraint(self, constr: Constraint, name: str) -> None: assert self.model is not None + assert self._varname_to_var is not None + assert constr.lhs is not None lhs = self.gp.quicksum( self._varname_to_var[varname] * coeff @@ -265,11 +270,18 @@ class GurobiSolver(InternalSolver): ] @overrides - def get_variables(self, with_static: bool = True) -> Dict[str, Variable]: + def get_variables( + self, + with_static: bool = True, + with_sa: bool = True, + ) -> Dict[str, Variable]: assert self.model is not None - variables = {} - gp_vars = self.model.getVars() - names = self.model.getAttr("varName", gp_vars) + + names = self._var_names + ub = self._var_ubs + lb = self._var_lbs + obj_coeff = self._var_obj_coeffs + values = None rc = None sa_obj_up = None @@ -279,26 +291,23 @@ class GurobiSolver(InternalSolver): sa_lb_up = None sa_lb_down = None vbasis = None - ub = None - lb = None - obj_coeff = None - if with_static: - lb = self.model.getAttr("lb", gp_vars) - ub = self.model.getAttr("ub", gp_vars) - obj_coeff = self.model.getAttr("obj", gp_vars) + if self.model.solCount > 0: - values = self.model.getAttr("x", gp_vars) + values = self.model.getAttr("x", self._gp_vars) + if self._has_lp_solution: - rc = self.model.getAttr("rc", gp_vars) - sa_obj_up = self.model.getAttr("saobjUp", gp_vars) - sa_obj_down = self.model.getAttr("saobjLow", gp_vars) - sa_ub_up = self.model.getAttr("saubUp", gp_vars) - sa_ub_down = self.model.getAttr("saubLow", gp_vars) - sa_lb_up = self.model.getAttr("salbUp", gp_vars) - sa_lb_down = self.model.getAttr("salbLow", gp_vars) - vbasis = self.model.getAttr("vbasis", gp_vars) + rc = self.model.getAttr("rc", self._gp_vars) + vbasis = self.model.getAttr("vbasis", self._gp_vars) + if with_sa: + sa_obj_up = self.model.getAttr("saobjUp", self._gp_vars) + sa_obj_down = self.model.getAttr("saobjLow", self._gp_vars) + sa_ub_up = self.model.getAttr("saubUp", self._gp_vars) + sa_ub_down = self.model.getAttr("saubLow", self._gp_vars) + sa_lb_up = self.model.getAttr("salbUp", self._gp_vars) + sa_lb_down = self.model.getAttr("salbLow", self._gp_vars) - for (i, gp_var) in enumerate(gp_vars): + variables = {} + for (i, gp_var) in enumerate(self._gp_vars): assert len(names[i]) > 0, "Empty variable name detected." assert ( names[i] not in variables @@ -311,24 +320,12 @@ class GurobiSolver(InternalSolver): var.lower_bound = lb[i] var.upper_bound = ub[i] var.obj_coeff = obj_coeff[i] - var.type = self._original_vtype[gp_var] + var.type = self._var_types[i] if values is not None: var.value = values[i] if rc is not None: - assert sa_obj_up is not None - assert sa_obj_down is not None - assert sa_ub_up is not None - assert sa_ub_down is not None - assert sa_lb_up is not None - assert sa_lb_down is not None assert vbasis is not None var.reduced_cost = rc[i] - var.sa_obj_up = sa_obj_up[i] - var.sa_obj_down = sa_obj_down[i] - var.sa_ub_up = sa_ub_up[i] - var.sa_ub_down = sa_ub_down[i] - var.sa_lb_up = sa_lb_up[i] - var.sa_lb_down = sa_lb_down[i] if vbasis[i] == 0: var.basis_status = "B" elif vbasis[i] == -1: @@ -339,6 +336,19 @@ class GurobiSolver(InternalSolver): var.basis_status = "S" else: raise Exception(f"unknown vbasis: {vbasis}") + if with_sa: + assert sa_obj_up is not None + assert sa_obj_down is not None + assert sa_ub_up is not None + assert sa_ub_down is not None + assert sa_lb_up is not None + assert sa_lb_down is not None + var.sa_obj_up = sa_obj_up[i] + var.sa_obj_down = sa_obj_down[i] + var.sa_ub_up = sa_ub_up[i] + var.sa_ub_down = sa_ub_down[i] + var.sa_lb_up = sa_lb_up[i] + var.sa_lb_down = sa_lb_down[i] variables[names[i]] = var return variables @@ -479,15 +489,17 @@ class GurobiSolver(InternalSolver): streams += [sys.stdout] self._apply_params(streams) assert self.model is not None - for var in self._bin_vars: - var.vtype = self.gp.GRB.CONTINUOUS - var.lb = 0.0 - var.ub = 1.0 + for (i, var) in enumerate(self._gp_vars): + if self._var_types[i] == "B": + var.vtype = self.gp.GRB.CONTINUOUS + var.lb = 0.0 + var.ub = 1.0 with _RedirectOutput(streams): self.model.optimize() self._dirty = False - for var in self._bin_vars: - var.vtype = self.gp.GRB.BINARY + for (i, var) in enumerate(self._gp_vars): + if self._var_types[i] == "B": + var.vtype = self.gp.GRB.BINARY log = streams[0].getvalue() self._has_lp_solution = self.model.solCount > 0 self._has_mip_solution = False @@ -577,33 +589,40 @@ class GurobiSolver(InternalSolver): def _update_vars(self) -> None: assert self.model is not None - self._varname_to_var.clear() - self._original_vtype = {} - self._bin_vars.clear() - for var in self.model.getVars(): - assert var.varName not in self._varname_to_var, ( - f"Duplicated variable name detected: {var.varName}. " - f"Unique variable names are currently required." + gp_vars = self.model.getVars() + var_names = self.model.getAttr("varName", gp_vars) + var_types = self.model.getAttr("vtype", gp_vars) + var_ubs = self.model.getAttr("ub", gp_vars) + var_lbs = self.model.getAttr("lb", gp_vars) + var_obj_coeffs = self.model.getAttr("obj", gp_vars) + varname_to_var: Dict = {} + for (i, gp_var) in enumerate(gp_vars): + assert var_names[i] not in varname_to_var, ( + f"Duplicated variable name detected: {var_names[i]}. " + f"Unique variable var_names are currently required." ) - self._varname_to_var[var.varName] = var - vtype = var.vtype - if vtype == "I": - assert var.ub == 1.0, ( + if var_types[i] == "I": + assert var_ubs[i] == 1.0, ( "Only binary and continuous variables are currently supported. " "Integer variable {var.varName} has upper bound {var.ub}." ) - assert var.lb == 0.0, ( + assert var_lbs[i] == 0.0, ( "Only binary and continuous variables are currently supported. " "Integer variable {var.varName} has lower bound {var.ub}." ) - vtype = "B" - assert vtype in ["B", "C"], ( + var_types[i] = "B" + assert var_types[i] in ["B", "C"], ( "Only binary and continuous variables are currently supported. " "Variable {var.varName} has type {vtype}." ) - self._original_vtype[var] = vtype - if vtype == "B": - self._bin_vars.append(var) + varname_to_var[var_names[i]] = gp_var + self._varname_to_var = varname_to_var + self._gp_vars = gp_vars + self._var_names = var_names + self._var_types = var_types + self._var_lbs = var_lbs + self._var_ubs = var_ubs + self._var_obj_coeffs = var_obj_coeffs def __getstate__(self) -> Dict: return {