diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index e6e04e1..7060c2a 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -458,52 +458,31 @@ class GurobiSolver(InternalSolver): var.type = self._original_vtype[gp_var] if self._has_lp_solution: - var.reduced_cost = gp_var.rc - var.sa_obj_up = gp_var.saobjUp - var.sa_obj_down = gp_var.saobjLow - var.sa_ub_up = gp_var.saubUp - var.sa_ub_down = gp_var.saubLow - var.sa_lb_up = gp_var.salbUp - var.sa_lb_down = gp_var.salbLow - vbasis = gp_var.vbasis - if vbasis == 0: - var.basis_status = "B" - elif vbasis == -1: - var.basis_status = "L" - elif vbasis == -2: - var.basis_status = "U" - elif vbasis == -3: - var.basis_status = "S" - else: - raise Exception(f"unknown vbasis: {vbasis}") + self._parse_gurobi_var_lp(gp_var, var) if self._has_lp_solution or self._has_mip_solution: var.value = gp_var.x return var - def _parse_gurobi_constraint(self, gp_constr: Any) -> Constraint: - assert self.model is not None - expr = self.model.getRow(gp_constr) - lhs: Dict[str, float] = {} - for i in range(expr.size()): - lhs[expr.getVar(i).varName] = expr.getCoeff(i) - constr = Constraint( - rhs=gp_constr.rhs, - lhs=lhs, - sense=gp_constr.sense, - ) - if self._has_lp_solution: - constr.dual_value = gp_constr.pi - constr.sa_rhs_up = gp_constr.sarhsup - constr.sa_rhs_down = gp_constr.sarhslow - if gp_constr.cbasis == 0: - constr.basis_status = "B" - elif gp_constr.cbasis == -1: - constr.basis_status = "N" - else: - raise Exception(f"unknown cbasis: {gp_constr.cbasis}") - if self._has_lp_solution or self._has_mip_solution: - constr.slack = gp_constr.slack - return constr + @staticmethod + def _parse_gurobi_var_lp(gp_var, var): + var.reduced_cost = gp_var.rc + var.sa_obj_up = gp_var.saobjUp + var.sa_obj_down = gp_var.saobjLow + var.sa_ub_up = gp_var.saubUp + var.sa_ub_down = gp_var.saubLow + var.sa_lb_up = gp_var.salbUp + var.sa_lb_down = gp_var.salbLow + vbasis = gp_var.vbasis + if vbasis == 0: + var.basis_status = "B" + elif vbasis == -1: + var.basis_status = "L" + elif vbasis == -2: + var.basis_status = "U" + elif vbasis == -3: + var.basis_status = "S" + else: + raise Exception(f"unknown vbasis: {vbasis}") def _raise_if_callback(self) -> None: if self.cb_where is not None: @@ -541,6 +520,31 @@ class GurobiSolver(InternalSolver): self.model = None self.cb_where = None + def _parse_gurobi_constraint(self, gp_constr: Any) -> Constraint: + assert self.model is not None + expr = self.model.getRow(gp_constr) + lhs: Dict[str, float] = {} + for i in range(expr.size()): + lhs[expr.getVar(i).varName] = expr.getCoeff(i) + constr = Constraint( + rhs=gp_constr.rhs, + lhs=lhs, + sense=gp_constr.sense, + ) + if self._has_lp_solution: + constr.dual_value = gp_constr.pi + constr.sa_rhs_up = gp_constr.sarhsup + constr.sa_rhs_down = gp_constr.sarhslow + if gp_constr.cbasis == 0: + constr.basis_status = "B" + elif gp_constr.cbasis == -1: + constr.basis_status = "N" + else: + raise Exception(f"unknown cbasis: {gp_constr.cbasis}") + if self._has_lp_solution or self._has_mip_solution: + constr.slack = gp_constr.slack + return constr + class GurobiTestInstanceInfeasible(Instance): @overrides diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index 86d0909..00a74ca 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -505,7 +505,10 @@ class BasePyomoSolver(InternalSolver): self._varname_to_var = {} for var in self.model.component_objects(Var): for idx in var: - self._varname_to_var[f"{var.name}[{idx}]"] = var[idx] + varname = f"{var.name}[{idx}]" + if idx is None: + varname = var.name + self._varname_to_var[varname] = var[idx] self._all_vars += [var[idx]] if var[idx].domain == pyomo.core.base.set_types.Binary: self._bin_vars += [var[idx]] diff --git a/miplearn/solvers/pyomo/gurobi.py b/miplearn/solvers/pyomo/gurobi.py index 6a242a6..0a4152b 100644 --- a/miplearn/solvers/pyomo/gurobi.py +++ b/miplearn/solvers/pyomo/gurobi.py @@ -3,12 +3,14 @@ # Released under the modified BSD license. See COPYING.md for more details. import logging -from typing import Optional +from typing import Optional, List, Dict from overrides import overrides from pyomo import environ as pe from scipy.stats import randint +from miplearn.features import Variable +from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.pyomo.base import BasePyomoSolver from miplearn.types import SolverParams, BranchPriorities @@ -39,16 +41,8 @@ class GurobiPyomoSolver(BasePyomoSolver): ) @overrides - def _extract_node_count(self, log: str) -> int: - return max(1, int(self._pyomo_solver._solver_model.getAttr("NodeCount"))) - - @overrides - def _get_warm_start_regexp(self) -> str: - return "MIP start with objective ([0-9.e+-]*)" - - @overrides - def _get_node_count_regexp(self) -> Optional[str]: - return None + def clone(self) -> "GurobiPyomoSolver": + return GurobiPyomoSolver(params=self.params) @overrides def set_branching_priorities(self, priorities: BranchPriorities) -> None: @@ -62,5 +56,44 @@ class GurobiPyomoSolver(BasePyomoSolver): gvar.setAttr(GRB.Attr.BranchPriority, int(round(priority))) @overrides - def clone(self) -> "GurobiPyomoSolver": - return GurobiPyomoSolver(params=self.params) + def get_variables(self) -> Dict[str, Variable]: + variables = super().get_variables() + if self._has_lp_solution: + for (varname, var) in variables.items(): + pvar = self._varname_to_var[varname] + gvar = self._pyomo_solver._pyomo_var_to_solver_var_map[pvar] + GurobiSolver._parse_gurobi_var_lp(gvar, var) + + return variables + + @overrides + def get_variable_attrs(self) -> List[str]: + return [ + "basis_status", + "category", + "lower_bound", + "obj_coeff", + "reduced_cost", + "sa_lb_down", + "sa_lb_up", + "sa_obj_down", + "sa_obj_up", + "sa_ub_down", + "sa_ub_up", + "type", + "upper_bound", + "user_features", + "value", + ] + + @overrides + def _extract_node_count(self, log: str) -> int: + return max(1, int(self._pyomo_solver._solver_model.getAttr("NodeCount"))) + + @overrides + def _get_warm_start_regexp(self) -> str: + return "MIP start with objective ([0-9.e+-]*)" + + @overrides + def _get_node_count_regexp(self) -> Optional[str]: + return None