mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Implement more compact get_variables
This commit is contained in:
@@ -6,11 +6,11 @@ import re
|
||||
import sys
|
||||
from io import StringIO
|
||||
from random import randint
|
||||
from typing import List, Any, Dict, Optional, Hashable
|
||||
from typing import List, Any, Dict, Optional, Hashable, Tuple, cast, TYPE_CHECKING
|
||||
|
||||
from overrides import overrides
|
||||
|
||||
from miplearn.features import Constraint, Variable
|
||||
from miplearn.features import Constraint, Variable, VariableFeatures
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.solvers import _RedirectOutput
|
||||
from miplearn.solvers.internal import (
|
||||
@@ -27,6 +27,9 @@ from miplearn.types import (
|
||||
Solution,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import gurobipy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -69,12 +72,12 @@ class GurobiSolver(InternalSolver):
|
||||
self._has_mip_solution = False
|
||||
|
||||
self._varname_to_var: Dict[str, "gurobipy.Var"] = {}
|
||||
self._gp_vars: List["gurobipy.Var"] = []
|
||||
self._var_names: List[str] = []
|
||||
self._var_types: List[str] = []
|
||||
self._var_lbs: List[float] = []
|
||||
self._var_ubs: List[float] = []
|
||||
self._var_obj_coeffs: List[float] = []
|
||||
self._gp_vars: Tuple["gurobipy.Var", ...] = tuple()
|
||||
self._var_names: Tuple[str, ...] = tuple()
|
||||
self._var_types: Tuple[str, ...] = tuple()
|
||||
self._var_lbs: Tuple[float, ...] = tuple()
|
||||
self._var_ubs: Tuple[float, ...] = tuple()
|
||||
self._var_obj_coeffs: Tuple[float, ...] = tuple()
|
||||
|
||||
if self.lazy_cb_frequency == 1:
|
||||
self.lazy_cb_where = [self.gp.GRB.Callback.MIPSOL]
|
||||
@@ -267,10 +270,27 @@ class GurobiSolver(InternalSolver):
|
||||
"upper_bound",
|
||||
"user_features",
|
||||
"value",
|
||||
# new attributes
|
||||
"names",
|
||||
"basis_status",
|
||||
"categories",
|
||||
"lower_bounds",
|
||||
"obj_coeffs",
|
||||
"reduced_costs",
|
||||
"sa_lb_down",
|
||||
"sa_lb_up",
|
||||
"sa_obj_down",
|
||||
"sa_obj_up",
|
||||
"sa_ub_down",
|
||||
"sa_ub_up",
|
||||
"types",
|
||||
"upper_bounds",
|
||||
"user_features",
|
||||
"values",
|
||||
]
|
||||
|
||||
@overrides
|
||||
def get_variables(
|
||||
def get_variables_old(
|
||||
self,
|
||||
with_static: bool = True,
|
||||
with_sa: bool = True,
|
||||
@@ -352,6 +372,76 @@ class GurobiSolver(InternalSolver):
|
||||
variables[names[i]] = var
|
||||
return variables
|
||||
|
||||
@overrides
|
||||
def get_variables(
|
||||
self,
|
||||
with_static: bool = True,
|
||||
with_sa: bool = True,
|
||||
) -> VariableFeatures:
|
||||
model = self.model
|
||||
assert model is not None
|
||||
|
||||
def _parse_gurobi_vbasis(b: int) -> str:
|
||||
if b == 0:
|
||||
return "B"
|
||||
elif b == -1:
|
||||
return "L"
|
||||
elif b == -2:
|
||||
return "U"
|
||||
elif b == -3:
|
||||
return "S"
|
||||
else:
|
||||
raise Exception(f"unknown vbasis: {basis_status}")
|
||||
|
||||
names, upper_bounds, lower_bounds, types, values = None, None, None, None, None
|
||||
obj_coeffs, reduced_costs, basis_status = None, None, None
|
||||
sa_obj_up, sa_ub_up, sa_lb_up = None, None, None
|
||||
sa_obj_down, sa_ub_down, sa_lb_down = None, None, None
|
||||
|
||||
if with_static:
|
||||
names = self._var_names
|
||||
upper_bounds = self._var_ubs
|
||||
lower_bounds = self._var_lbs
|
||||
types = self._var_types
|
||||
obj_coeffs = self._var_obj_coeffs
|
||||
|
||||
if self._has_lp_solution:
|
||||
reduced_costs = tuple(model.getAttr("rc", self._gp_vars))
|
||||
basis_status = tuple(
|
||||
map(
|
||||
_parse_gurobi_vbasis,
|
||||
model.getAttr("vbasis", self._gp_vars),
|
||||
)
|
||||
)
|
||||
|
||||
if with_sa:
|
||||
sa_obj_up = tuple(model.getAttr("saobjUp", self._gp_vars))
|
||||
sa_obj_down = tuple(model.getAttr("saobjLow", self._gp_vars))
|
||||
sa_ub_up = tuple(model.getAttr("saubUp", self._gp_vars))
|
||||
sa_ub_down = tuple(model.getAttr("saubLow", self._gp_vars))
|
||||
sa_lb_up = tuple(model.getAttr("salbUp", self._gp_vars))
|
||||
sa_lb_down = tuple(model.getAttr("salbLow", self._gp_vars))
|
||||
|
||||
if model.solCount > 0:
|
||||
values = tuple(model.getAttr("x", self._gp_vars))
|
||||
|
||||
return VariableFeatures(
|
||||
names=names,
|
||||
upper_bounds=upper_bounds,
|
||||
lower_bounds=lower_bounds,
|
||||
types=types,
|
||||
obj_coeffs=obj_coeffs,
|
||||
reduced_costs=reduced_costs,
|
||||
basis_status=basis_status,
|
||||
sa_obj_up=sa_obj_up,
|
||||
sa_obj_down=sa_obj_down,
|
||||
sa_ub_up=sa_ub_up,
|
||||
sa_ub_down=sa_ub_down,
|
||||
sa_lb_up=sa_lb_up,
|
||||
sa_lb_down=sa_lb_down,
|
||||
values=values,
|
||||
)
|
||||
|
||||
@overrides
|
||||
def is_constraint_satisfied(self, constr: Constraint, tol: float = 1e-6) -> bool:
|
||||
assert constr.lhs is not None
|
||||
@@ -589,12 +679,12 @@ class GurobiSolver(InternalSolver):
|
||||
|
||||
def _update_vars(self) -> None:
|
||||
assert self.model is not None
|
||||
gp_vars = self.model.getVars()
|
||||
var_names = self.model.getAttr("varName", gp_vars)
|
||||
var_types = self.model.getAttr("vtype", gp_vars)
|
||||
var_ubs = self.model.getAttr("ub", gp_vars)
|
||||
var_lbs = self.model.getAttr("lb", gp_vars)
|
||||
var_obj_coeffs = self.model.getAttr("obj", gp_vars)
|
||||
gp_vars: List["gurobipy.Var"] = self.model.getVars()
|
||||
var_names: List[str] = self.model.getAttr("varName", gp_vars)
|
||||
var_types: List[str] = self.model.getAttr("vtype", gp_vars)
|
||||
var_ubs: List[float] = self.model.getAttr("ub", gp_vars)
|
||||
var_lbs: List[float] = self.model.getAttr("lb", gp_vars)
|
||||
var_obj_coeffs: List[float] = self.model.getAttr("obj", gp_vars)
|
||||
varname_to_var: Dict = {}
|
||||
for (i, gp_var) in enumerate(gp_vars):
|
||||
assert var_names[i] not in varname_to_var, (
|
||||
@@ -617,12 +707,12 @@ class GurobiSolver(InternalSolver):
|
||||
)
|
||||
varname_to_var[var_names[i]] = gp_var
|
||||
self._varname_to_var = varname_to_var
|
||||
self._gp_vars = gp_vars
|
||||
self._var_names = var_names
|
||||
self._var_types = var_types
|
||||
self._var_lbs = var_lbs
|
||||
self._var_ubs = var_ubs
|
||||
self._var_obj_coeffs = var_obj_coeffs
|
||||
self._gp_vars = tuple(gp_vars)
|
||||
self._var_names = tuple(var_names)
|
||||
self._var_types = tuple(var_types)
|
||||
self._var_lbs = tuple(var_lbs)
|
||||
self._var_ubs = tuple(var_ubs)
|
||||
self._var_obj_coeffs = tuple(var_obj_coeffs)
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user