Add types to InternalSolver

This commit is contained in:
2021-01-20 10:07:28 -06:00
parent 69a82172b9
commit 1971389a57
10 changed files with 267 additions and 174 deletions

View File

@@ -1,14 +1,22 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
import re
import sys
import logging
from io import StringIO
from random import randint
from typing import List, Any, Dict, Union
from . import RedirectOutput
from .internal import InternalSolver
from .internal import (
InternalSolver,
LPSolveStats,
IterationCallback,
LazyCallback,
MIPSolveStats,
)
from .. import Instance
logger = logging.getLogger(__name__)
@@ -35,13 +43,14 @@ class GurobiSolver(InternalSolver):
if params is None:
params = {}
params["InfUnbdInfo"] = True
from gurobipy import GRB
import gurobipy
self.GRB = GRB
self.gp = gurobipy
self.GRB = gurobipy.GRB
self.instance = None
self.model = None
self.params = params
self._all_vars = None
self._all_vars: Dict = {}
self._bin_vars = None
self.cb_where = None
assert lazy_cb_frequency in [1, 2]
@@ -50,10 +59,15 @@ class GurobiSolver(InternalSolver):
else:
self.lazy_cb_where = [self.GRB.Callback.MIPSOL, self.GRB.Callback.MIPNODE]
def set_instance(self, instance, model=None):
def set_instance(
self,
instance: Instance,
model: Any = None,
) -> None:
self._raise_if_callback()
if model is None:
model = instance.to_model()
assert isinstance(model, self.gp.Model)
self.instance = instance
self.model = model
self.model.update()
@@ -67,7 +81,7 @@ class GurobiSolver(InternalSolver):
self._all_vars = {}
self._bin_vars = {}
for var in self.model.getVars():
m = re.search(r"([^[]*)\[(.*)\]", var.varName)
m = re.search(r"([^[]*)\[(.*)]", var.varName)
if m is None:
name = var.varName
idx = [0]
@@ -93,9 +107,12 @@ class GurobiSolver(InternalSolver):
if "seed" not in [k.lower() for k in self.params.keys()]:
self.model.setParam("Seed", randint(0, 1_000_000))
def solve_lp(self, tee=False):
def solve_lp(
self,
tee: bool = False,
) -> LPSolveStats:
self._raise_if_callback()
streams = [StringIO()]
streams: List[Any] = [StringIO()]
if tee:
streams += [sys.stdout]
self._apply_params(streams)
@@ -110,9 +127,17 @@ class GurobiSolver(InternalSolver):
for (idx, var) in vardict.items():
var.vtype = self.GRB.BINARY
log = streams[0].getvalue()
return {"Optimal value": self.model.objVal, "Log": log}
return {
"Optimal value": self.model.objVal,
"Log": log,
}
def solve(self, tee=False, iteration_cb=None, lazy_cb=None):
def solve(
self,
tee: bool = False,
iteration_cb: IterationCallback = None,
lazy_cb: LazyCallback = None,
) -> MIPSolveStats:
self._raise_if_callback()
def cb_wrapper(cb_model, cb_where):
@@ -129,7 +154,7 @@ class GurobiSolver(InternalSolver):
self.params["LazyConstraints"] = 1
total_wallclock_time = 0
total_nodes = 0
streams = [StringIO()]
streams: List[Any] = [StringIO()]
if tee:
streams += [sys.stdout]
self._apply_params(streams)
@@ -155,15 +180,42 @@ class GurobiSolver(InternalSolver):
sense = "max"
lb = self.model.objVal
ub = self.model.objBound
return {
stats: MIPSolveStats = {
"Lower bound": lb,
"Upper bound": ub,
"Wallclock time": total_wallclock_time,
"Nodes": total_nodes,
"Sense": sense,
"Log": log,
"Warm start value": self._extract_warm_start_value(log),
}
ws_value = self._extract_warm_start_value(log)
if ws_value is not None:
stats["Warm start value"] = ws_value
return stats
def get_solution(self) -> Dict:
self._raise_if_callback()
solution: Dict = {}
for (varname, vardict) in self._all_vars.items():
solution[varname] = {}
for (idx, var) in vardict.items():
solution[varname][idx] = var.x
return solution
def set_warm_start(self, solution: Dict) -> None:
self._raise_if_callback()
self._clear_warm_start()
count_fixed, count_total = 0, 0
for (varname, vardict) in solution.items():
for (idx, value) in vardict.items():
count_total += 1
if value is not None:
count_fixed += 1
self._all_vars[varname][idx].start = value
logger.info(
"Setting start values for %d variables (out of %d)"
% (count_fixed, count_total)
)
def get_sense(self):
if self.model.modelSense == 1:
@@ -171,16 +223,6 @@ class GurobiSolver(InternalSolver):
else:
return "max"
def get_solution(self):
self._raise_if_callback()
solution = {}
for (varname, vardict) in self._all_vars.items():
solution[varname] = {}
for (idx, var) in vardict.items():
solution[varname][idx] = var.x
return solution
def get_value(self, var_name, index):
var = self._all_vars[var_name][index]
return self._get_value(var)
@@ -229,25 +271,10 @@ class GurobiSolver(InternalSolver):
else:
self.model.addConstr(constraint, name=name)
def set_warm_start(self, solution):
self._raise_if_callback()
count_fixed, count_total = 0, 0
for (varname, vardict) in solution.items():
for (idx, value) in vardict.items():
count_total += 1
if value is not None:
count_fixed += 1
self._all_vars[varname][idx].start = value
logger.info(
"Setting start values for %d variables (out of %d)"
% (count_fixed, count_total)
)
def clear_warm_start(self):
self._raise_if_callback()
for (varname, vardict) in self._all_vars:
def _clear_warm_start(self):
for (varname, vardict) in self._all_vars.items():
for (idx, var) in vardict.items():
var[idx].start = self.GRB.UNDEFINED
var.start = self.GRB.UNDEFINED
def fix(self, solution):
self._raise_if_callback()
@@ -311,17 +338,14 @@ class GurobiSolver(InternalSolver):
self.model = self.model.relax()
self._update_vars()
def set_branching_priorities(self, priorities):
self._raise_if_callback()
logger.warning("set_branching_priorities not implemented")
def _extract_warm_start_value(self, log):
ws = self.__extract(log, "MIP start with objective ([0-9.e+-]*)")
if ws is not None:
ws = float(ws)
return ws
def __extract(self, log, regexp, default=None):
@staticmethod
def __extract(log, regexp, default=None):
value = default
for line in log.splitlines():
matches = re.findall(regexp, line)