diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 0af1b9a..f40e9be 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -154,16 +154,55 @@ class GurobiSolver(InternalSolver): @overrides def get_constraints(self) -> Dict[str, Constraint]: - assert self.model is not None + model = self.model + assert model is not None self._raise_if_callback() if self._dirty: - self.model.update() + model.update() self._dirty = False + gp_constrs = model.getConstrs() + var_names = model.getAttr("varName", model.getVars()) + constr_names = model.getAttr("constrName", gp_constrs) + rhs = model.getAttr("rhs", gp_constrs) + sense = model.getAttr("sense", gp_constrs) + dual_value = None + sa_rhs_up = None + sa_rhs_down = None + slack = None + basis_status = None + if self._has_lp_solution: + dual_value = model.getAttr("pi", gp_constrs) + sa_rhs_up = model.getAttr("saRhsUp", gp_constrs) + sa_rhs_down = model.getAttr("saRhsLow", gp_constrs) + basis_status = model.getAttr("cbasis", gp_constrs) + if self._has_lp_solution or self._has_mip_solution: + slack = model.getAttr("slack", gp_constrs) constraints: Dict[str, Constraint] = {} - for c in self.model.getConstrs(): - constr = self._parse_gurobi_constraint(c) - assert c.constrName not in constraints - constraints[c.constrName] = constr + for (i, gp_constr) in enumerate(gp_constrs): + expr = model.getRow(gp_constr) + lhs = {} + for j in range(expr.size()): + lhs[var_names[expr.getVar(j).index]] = expr.getCoeff(j) + assert ( + constr_names[i] not in constraints + ), f"Duplicated constraint name detected: {constr_names[i]}" + constraint = Constraint(lhs=lhs, rhs=rhs[i], sense=sense[i]) + if dual_value is not None: + assert sa_rhs_up is not None + assert sa_rhs_down is not None + assert basis_status is not None + constraint.dual_value = dual_value[i] + constraint.sa_rhs_up = sa_rhs_up[i] + constraint.sa_rhs_down = sa_rhs_down[i] + if gp_constr.cbasis == 0: + constraint.basis_status = "B" + elif gp_constr.cbasis == -1: + constraint.basis_status = "N" + else: + raise Exception(f"unknown cbasis: {gp_constr.cbasis}") + if slack is not None: + constraint.slack = slack[i] + constraints[constr_names[i]] = constraint return constraints @overrides @@ -553,31 +592,6 @@ class GurobiSolver(InternalSolver): self.model = None self.cb_where = None - def _parse_gurobi_constraint(self, gp_constr: Any) -> Constraint: - assert self.model is not None - expr = self.model.getRow(gp_constr) - lhs: Dict[str, float] = {} - for i in range(expr.size()): - lhs[expr.getVar(i).varName] = expr.getCoeff(i) - constr = Constraint( - rhs=gp_constr.rhs, - lhs=lhs, - sense=gp_constr.sense, - ) - if self._has_lp_solution: - constr.dual_value = gp_constr.pi - constr.sa_rhs_up = gp_constr.sarhsup - constr.sa_rhs_down = gp_constr.sarhslow - if gp_constr.cbasis == 0: - constr.basis_status = "B" - elif gp_constr.cbasis == -1: - constr.basis_status = "N" - else: - raise Exception(f"unknown cbasis: {gp_constr.cbasis}") - if self._has_lp_solution or self._has_mip_solution: - constr.slack = gp_constr.slack - return constr - class GurobiTestInstanceInfeasible(Instance): @overrides