Reformat source code with Black; add pre-commit hooks and CI checks

This commit is contained in:
2020-12-05 10:59:33 -06:00
parent 3823931382
commit d99600f101
49 changed files with 1291 additions and 972 deletions

View File

@@ -13,10 +13,11 @@ logger = logging.getLogger(__name__)
class GurobiSolver(InternalSolver):
def __init__(self,
params=None,
lazy_cb_frequency=1,
):
def __init__(
self,
params=None,
lazy_cb_frequency=1,
):
"""
An InternalSolver backed by Gurobi's Python API (without Pyomo).
@@ -33,6 +34,7 @@ class GurobiSolver(InternalSolver):
if params is None:
params = {}
from gurobipy import GRB
self.GRB = GRB
self.instance = None
self.model = None
@@ -44,8 +46,7 @@ class GurobiSolver(InternalSolver):
if lazy_cb_frequency == 1:
self.lazy_cb_where = [self.GRB.Callback.MIPSOL]
else:
self.lazy_cb_where = [self.GRB.Callback.MIPSOL,
self.GRB.Callback.MIPNODE]
self.lazy_cb_where = [self.GRB.Callback.MIPSOL, self.GRB.Callback.MIPNODE]
def set_instance(self, instance, model=None):
self._raise_if_callback()
@@ -70,14 +71,15 @@ class GurobiSolver(InternalSolver):
idx = [0]
else:
name = m.group(1)
idx = tuple(int(k) if k.isdecimal() else k
for k in m.group(2).split(","))
idx = tuple(
int(k) if k.isdecimal() else k for k in m.group(2).split(",")
)
if len(idx) == 1:
idx = idx[0]
if name not in self._all_vars:
self._all_vars[name] = {}
self._all_vars[name][idx] = var
if var.vtype != 'C':
if var.vtype != "C":
if name not in self._bin_vars:
self._bin_vars[name] = {}
self._bin_vars[name][idx] = var
@@ -103,15 +105,9 @@ class GurobiSolver(InternalSolver):
for (idx, var) in vardict.items():
var.vtype = self.GRB.BINARY
log = streams[0].getvalue()
return {
"Optimal value": self.model.objVal,
"Log": log
}
return {"Optimal value": self.model.objVal, "Log": log}
def solve(self,
tee=False,
iteration_cb=None,
lazy_cb=None):
def solve(self, tee=False, iteration_cb=None, lazy_cb=None):
self._raise_if_callback()
def cb_wrapper(cb_model, cb_where):
@@ -133,7 +129,7 @@ class GurobiSolver(InternalSolver):
if tee:
streams += [sys.stdout]
if iteration_cb is None:
iteration_cb = lambda : False
iteration_cb = lambda: False
while True:
logger.debug("Solving MIP...")
with RedirectOutput(streams):
@@ -187,7 +183,9 @@ class GurobiSolver(InternalSolver):
elif self.cb_where is None:
return var.x
else:
raise Exception("get_value cannot be called from cb_where=%s" % self.cb_where)
raise Exception(
"get_value cannot be called from cb_where=%s" % self.cb_where
)
def get_variables(self):
self._raise_if_callback()
@@ -220,8 +218,10 @@ class GurobiSolver(InternalSolver):
if value is not None:
count_fixed += 1
self._all_vars[varname][idx].start = value
logger.info("Setting start values for %d variables (out of %d)" %
(count_fixed, count_total))
logger.info(
"Setting start values for %d variables (out of %d)"
% (count_fixed, count_total)
)
def clear_warm_start(self):
self._raise_if_callback()
@@ -248,10 +248,7 @@ class GurobiSolver(InternalSolver):
def extract_constraint(self, cid):
self._raise_if_callback()
constr = self.model.getConstrByName(cid)
cobj = (self.model.getRow(constr),
constr.sense,
constr.RHS,
constr.ConstrName)
cobj = (self.model.getRow(constr), constr.sense, constr.RHS, constr.ConstrName)
self.model.remove(constr)
return cobj
@@ -316,7 +313,7 @@ class GurobiSolver(InternalSolver):
value = matches[0]
return value
def __getstate__(self):
def __getstate__(self):
return {
"params": self.params,
"lazy_cb_where": self.lazy_cb_where,
@@ -324,6 +321,7 @@ class GurobiSolver(InternalSolver):
def __setstate__(self, state):
from gurobipy import GRB
self.params = state["params"]
self.lazy_cb_where = state["lazy_cb_where"]
self.GRB = GRB
@@ -331,4 +329,4 @@ class GurobiSolver(InternalSolver):
self.model = None
self._all_vars = None
self._bin_vars = None
self.cb_where = None
self.cb_where = None