Refer to variables by varname instead of (vname, index)

This commit is contained in:
2021-04-07 10:56:31 -05:00
parent 856b595d5e
commit 1cf6124757
22 changed files with 467 additions and 516 deletions

View File

@@ -6,7 +6,9 @@ import re
import sys
from io import StringIO
from random import randint
from typing import List, Any, Dict, Optional, cast, Tuple, Union
from typing import List, Any, Dict, Optional
from overrides import overrides
from miplearn.instance.base import Instance
from miplearn.solvers import _RedirectOutput
@@ -17,7 +19,12 @@ from miplearn.solvers.internal import (
LazyCallback,
MIPSolveStats,
)
from miplearn.types import VarIndex, SolverParams, Solution, UserCutCallback
from miplearn.types import (
SolverParams,
UserCutCallback,
Solution,
VariableName,
)
logger = logging.getLogger(__name__)
@@ -52,8 +59,8 @@ class GurobiSolver(InternalSolver):
self.instance: Optional[Instance] = None
self.model: Optional["gurobipy.Model"] = None
self.params: SolverParams = params
self._all_vars: Dict = {}
self._bin_vars: Optional[Dict[str, Dict[VarIndex, "gurobipy.Var"]]] = None
self.varname_to_var: Dict[str, "gurobipy.Var"] = {}
self.bin_vars: List["gurobipy.Var"] = []
self.cb_where: Optional[int] = None
assert lazy_cb_frequency in [1, 2]
@@ -65,6 +72,7 @@ class GurobiSolver(InternalSolver):
self.gp.GRB.Callback.MIPNODE,
]
@overrides
def set_instance(
self,
instance: Instance,
@@ -85,30 +93,20 @@ class GurobiSolver(InternalSolver):
def _update_vars(self) -> None:
assert self.model is not None
self._all_vars = {}
self._bin_vars = {}
idx: VarIndex
self.varname_to_var.clear()
self.bin_vars.clear()
for var in self.model.getVars():
m = re.search(r"([^[]*)\[(.*)]", var.varName)
if m is None:
name = var.varName
idx = (0,)
else:
name = m.group(1)
parts = m.group(2).split(",")
idx = cast(
Tuple[Union[str, int]],
tuple(int(k) if k.isdecimal() else str(k) for k in parts),
)
if len(idx) == 1:
idx = idx[0]
if name not in self._all_vars:
self._all_vars[name] = {}
self._all_vars[name][idx] = var
if var.vtype != "C":
if name not in self._bin_vars:
self._bin_vars[name] = {}
self._bin_vars[name][idx] = var
assert var.varName not in self.varname_to_var, (
f"Duplicated variable name detected: {var.varName}. "
f"Unique variable names are currently required."
)
self.varname_to_var[var.varName] = var
assert var.vtype in ["B", "C"], (
"Only binary and continuous variables are currently supported. "
"Variable {var.varName} has type {var.vtype}."
)
if var.vtype == "B":
self.bin_vars.append(var)
def _apply_params(self, streams: List[Any]) -> None:
assert self.model is not None
@@ -118,6 +116,7 @@ class GurobiSolver(InternalSolver):
if "seed" not in [k.lower() for k in self.params.keys()]:
self.model.setParam("Seed", randint(0, 1_000_000))
@overrides
def solve_lp(
self,
tee: bool = False,
@@ -128,17 +127,14 @@ class GurobiSolver(InternalSolver):
streams += [sys.stdout]
self._apply_params(streams)
assert self.model is not None
assert self._bin_vars is not None
for (varname, vardict) in self._bin_vars.items():
for (idx, var) in vardict.items():
var.vtype = self.gp.GRB.CONTINUOUS
var.lb = 0.0
var.ub = 1.0
for var in self.bin_vars:
var.vtype = self.gp.GRB.CONTINUOUS
var.lb = 0.0
var.ub = 1.0
with _RedirectOutput(streams):
self.model.optimize()
for (varname, vardict) in self._bin_vars.items():
for (idx, var) in vardict.items():
var.vtype = self.gp.GRB.BINARY
for var in self.bin_vars:
var.vtype = self.gp.GRB.BINARY
log = streams[0].getvalue()
opt_value = None
if not self.is_infeasible():
@@ -148,6 +144,7 @@ class GurobiSolver(InternalSolver):
"LP log": log,
}
@overrides
def solve(
self,
tee: bool = False,
@@ -218,33 +215,30 @@ class GurobiSolver(InternalSolver):
}
return stats
@overrides
def get_solution(self) -> Optional[Solution]:
self._raise_if_callback()
assert self.model is not None
if self.model.solCount == 0:
return None
solution: Solution = {}
for (varname, vardict) in self._all_vars.items():
solution[varname] = {}
for (idx, var) in vardict.items():
solution[varname][idx] = var.x
return solution
return {v.varName: v.x for v in self.model.getVars()}
@overrides
def get_variable_names(self) -> List[VariableName]:
self._raise_if_callback()
assert self.model is not None
return [v.varName for v in self.model.getVars()]
@overrides
def set_warm_start(self, solution: Solution) -> None:
self._raise_if_callback()
self._clear_warm_start()
count_fixed, count_total = 0, 0
for (varname, vardict) in solution.items():
for (idx, value) in vardict.items():
count_total += 1
if value is not None:
count_fixed += 1
self._all_vars[varname][idx].start = value
logger.info(
"Setting start values for %d variables (out of %d)"
% (count_fixed, count_total)
)
for (var_name, value) in solution.items():
var = self.varname_to_var[var_name]
if value is not None:
var.start = value
@overrides
def get_sense(self) -> str:
assert self.model is not None
if self.model.modelSense == 1:
@@ -252,18 +246,12 @@ class GurobiSolver(InternalSolver):
else:
return "max"
def get_value(
self,
var_name: str,
index: VarIndex,
) -> Optional[float]:
var = self._all_vars[var_name][index]
return self._get_value(var)
@overrides
def is_infeasible(self) -> bool:
assert self.model is not None
return self.model.status in [self.gp.GRB.INFEASIBLE, self.gp.GRB.INF_OR_UNBD]
@overrides
def get_dual(self, cid: str) -> float:
assert self.model is not None
c = self.model.getConstrByName(cid)
@@ -288,15 +276,7 @@ class GurobiSolver(InternalSolver):
"get_value cannot be called from cb_where=%s" % self.cb_where
)
def get_empty_solution(self) -> Solution:
self._raise_if_callback()
solution: Solution = {}
for (varname, vardict) in self._all_vars.items():
solution[varname] = {}
for (idx, var) in vardict.items():
solution[varname][idx] = None
return solution
@overrides
def add_constraint(
self,
constraint: Any,
@@ -321,36 +301,39 @@ class GurobiSolver(InternalSolver):
else:
self.model.addConstr(constraint, name=name)
@overrides
def add_cut(self, cobj: Any) -> None:
assert self.model is not None
assert self.cb_where == self.gp.GRB.Callback.MIPNODE
self.model.cbCut(cobj)
def _clear_warm_start(self) -> None:
for (varname, vardict) in self._all_vars.items():
for (idx, var) in vardict.items():
var.start = self.gp.GRB.UNDEFINED
for var in self.varname_to_var.values():
var.start = self.gp.GRB.UNDEFINED
@overrides
def fix(self, solution: Solution) -> None:
self._raise_if_callback()
for (varname, vardict) in solution.items():
for (idx, value) in vardict.items():
if value is None:
continue
var = self._all_vars[varname][idx]
var.vtype = self.gp.GRB.CONTINUOUS
var.lb = value
var.ub = value
for (varname, value) in solution.items():
if value is None:
continue
var = self.varname_to_var[varname]
var.vtype = self.gp.GRB.CONTINUOUS
var.lb = value
var.ub = value
@overrides
def get_constraint_ids(self):
self._raise_if_callback()
self.model.update()
return [c.ConstrName for c in self.model.getConstrs()]
@overrides
def get_constraint_rhs(self, cid: str) -> float:
assert self.model is not None
return self.model.getConstrByName(cid).rhs
@overrides
def get_constraint_lhs(self, cid: str) -> Dict[str, float]:
assert self.model is not None
constr = self.model.getConstrByName(cid)
@@ -360,6 +343,7 @@ class GurobiSolver(InternalSolver):
lhs[expr.getVar(i).varName] = expr.getCoeff(i)
return lhs
@overrides
def extract_constraint(self, cid):
self._raise_if_callback()
constr = self.model.getConstrByName(cid)
@@ -367,6 +351,7 @@ class GurobiSolver(InternalSolver):
self.model.remove(constr)
return cobj
@overrides
def is_constraint_satisfied(self, cobj, tol=1e-6):
lhs, sense, rhs, name = cobj
if self.cb_where is not None:
@@ -386,21 +371,25 @@ class GurobiSolver(InternalSolver):
else:
raise Exception("Unknown sense: %s" % sense)
@overrides
def get_inequality_slacks(self) -> Dict[str, float]:
assert self.model is not None
ineqs = [c for c in self.model.getConstrs() if c.sense != "="]
return {c.ConstrName: c.Slack for c in ineqs}
@overrides
def set_constraint_sense(self, cid: str, sense: str) -> None:
assert self.model is not None
c = self.model.getConstrByName(cid)
c.Sense = sense
@overrides
def get_constraint_sense(self, cid: str) -> str:
assert self.model is not None
c = self.model.getConstrByName(cid)
return c.Sense
@overrides
def relax(self) -> None:
assert self.model is not None
self.model.update()
@@ -438,6 +427,4 @@ class GurobiSolver(InternalSolver):
self.lazy_cb_where = state["lazy_cb_where"]
self.instance = None
self.model = None
self._all_vars = None
self._bin_vars = None
self.cb_where = None