mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Add types to InternalSolver
This commit is contained in:
@@ -6,13 +6,20 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from io import StringIO
|
||||
from typing import Any, List, Dict
|
||||
|
||||
import pyomo
|
||||
from pyomo import environ as pe
|
||||
from pyomo.core import Var, Constraint
|
||||
|
||||
from .. import RedirectOutput
|
||||
from ..internal import InternalSolver
|
||||
from ..internal import (
|
||||
InternalSolver,
|
||||
LPSolveStats,
|
||||
IterationCallback,
|
||||
LazyCallback,
|
||||
MIPSolveStats,
|
||||
)
|
||||
from ...instance import Instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -40,23 +47,74 @@ class BasePyomoSolver(InternalSolver):
|
||||
for (key, value) in params.items():
|
||||
self._pyomo_solver.options[key] = value
|
||||
|
||||
def solve_lp(self, tee=False):
|
||||
def solve_lp(
|
||||
self,
|
||||
tee: bool = False,
|
||||
) -> LPSolveStats:
|
||||
for var in self._bin_vars:
|
||||
lb, ub = var.bounds
|
||||
var.setlb(lb)
|
||||
var.setub(ub)
|
||||
var.domain = pyomo.core.base.set_types.Reals
|
||||
self._pyomo_solver.update_var(var)
|
||||
results = self._pyomo_solver.solve(tee=tee)
|
||||
streams: List[Any] = [StringIO()]
|
||||
if tee:
|
||||
streams += [sys.stdout]
|
||||
with RedirectOutput(streams):
|
||||
results = self._pyomo_solver.solve(tee=True)
|
||||
for var in self._bin_vars:
|
||||
var.domain = pyomo.core.base.set_types.Binary
|
||||
self._pyomo_solver.update_var(var)
|
||||
return {
|
||||
"Optimal value": results["Problem"][0]["Lower bound"],
|
||||
"Log": streams[0].getvalue(),
|
||||
}
|
||||
|
||||
def get_solution(self):
|
||||
solution = {}
|
||||
def solve(
|
||||
self,
|
||||
tee: bool = False,
|
||||
iteration_cb: IterationCallback = None,
|
||||
lazy_cb: LazyCallback = None,
|
||||
) -> MIPSolveStats:
|
||||
if lazy_cb is not None:
|
||||
raise Exception("lazy callback not supported")
|
||||
total_wallclock_time = 0
|
||||
streams: List[Any] = [StringIO()]
|
||||
if tee:
|
||||
streams += [sys.stdout]
|
||||
if iteration_cb is None:
|
||||
iteration_cb = lambda: False
|
||||
self.instance.found_violated_lazy_constraints = []
|
||||
self.instance.found_violated_user_cuts = []
|
||||
while True:
|
||||
logger.debug("Solving MIP...")
|
||||
with RedirectOutput(streams):
|
||||
results = self._pyomo_solver.solve(
|
||||
tee=True,
|
||||
warmstart=self._is_warm_start_available,
|
||||
)
|
||||
total_wallclock_time += results["Solver"][0]["Wallclock time"]
|
||||
should_repeat = iteration_cb()
|
||||
if not should_repeat:
|
||||
break
|
||||
log = streams[0].getvalue()
|
||||
stats: MIPSolveStats = {
|
||||
"Lower bound": results["Problem"][0]["Lower bound"],
|
||||
"Upper bound": results["Problem"][0]["Upper bound"],
|
||||
"Wallclock time": total_wallclock_time,
|
||||
"Sense": self._obj_sense,
|
||||
"Log": log,
|
||||
}
|
||||
node_count = self._extract_node_count(log)
|
||||
ws_value = self._extract_warm_start_value(log)
|
||||
if node_count is not None:
|
||||
stats["Nodes"] = node_count
|
||||
if ws_value is not None:
|
||||
stats["Warm start value"] = ws_value
|
||||
return stats
|
||||
|
||||
def get_solution(self) -> Dict:
|
||||
solution: Dict = {}
|
||||
for var in self.model.component_objects(Var):
|
||||
solution[str(var)] = {}
|
||||
for index in var:
|
||||
@@ -65,22 +123,8 @@ class BasePyomoSolver(InternalSolver):
|
||||
solution[str(var)][index] = var[index].value
|
||||
return solution
|
||||
|
||||
def get_value(self, var_name, index):
|
||||
var = self._varname_to_var[var_name]
|
||||
return var[index].value
|
||||
|
||||
def get_variables(self):
|
||||
variables = {}
|
||||
for var in self.model.component_objects(Var):
|
||||
variables[str(var)] = []
|
||||
for index in var:
|
||||
if var[index].fixed:
|
||||
continue
|
||||
variables[str(var)] += [index]
|
||||
return variables
|
||||
|
||||
def set_warm_start(self, solution):
|
||||
self.clear_warm_start()
|
||||
def set_warm_start(self, solution: Dict) -> None:
|
||||
self._clear_warm_start()
|
||||
count_total, count_fixed = 0, 0
|
||||
for var_name in solution:
|
||||
var = self._varname_to_var[var_name]
|
||||
@@ -96,16 +140,13 @@ class BasePyomoSolver(InternalSolver):
|
||||
% (count_fixed, count_total)
|
||||
)
|
||||
|
||||
def clear_warm_start(self):
|
||||
for var in self._all_vars:
|
||||
if not var.fixed:
|
||||
var.value = None
|
||||
self._is_warm_start_available = False
|
||||
|
||||
def set_instance(self, instance, model=None):
|
||||
def set_instance(
|
||||
self,
|
||||
instance: Instance,
|
||||
model: Any = None,
|
||||
) -> None:
|
||||
if model is None:
|
||||
model = instance.to_model()
|
||||
assert isinstance(instance, Instance)
|
||||
assert isinstance(model, pe.ConcreteModel)
|
||||
self.instance = instance
|
||||
self.model = model
|
||||
@@ -114,6 +155,26 @@ class BasePyomoSolver(InternalSolver):
|
||||
self._update_vars()
|
||||
self._update_constrs()
|
||||
|
||||
def get_value(self, var_name, index):
|
||||
var = self._varname_to_var[var_name]
|
||||
return var[index].value
|
||||
|
||||
def get_variables(self):
|
||||
variables = {}
|
||||
for var in self.model.component_objects(Var):
|
||||
variables[str(var)] = []
|
||||
for index in var:
|
||||
if var[index].fixed:
|
||||
continue
|
||||
variables[str(var)] += [index]
|
||||
return variables
|
||||
|
||||
def _clear_warm_start(self):
|
||||
for var in self._all_vars:
|
||||
if not var.fixed:
|
||||
var.value = None
|
||||
self._is_warm_start_available = False
|
||||
|
||||
def _update_obj(self):
|
||||
self._obj_sense = "max"
|
||||
if self._pyomo_solver._objective.sense == pyomo.core.kernel.objective.minimize:
|
||||
@@ -158,46 +219,6 @@ class BasePyomoSolver(InternalSolver):
|
||||
self._pyomo_solver.add_constraint(constraint)
|
||||
self._update_constrs()
|
||||
|
||||
def solve(self, tee=False, iteration_cb=None, lazy_cb=None):
|
||||
if lazy_cb is not None:
|
||||
raise Exception("lazy callback not supported")
|
||||
total_wallclock_time = 0
|
||||
streams = [StringIO()]
|
||||
if tee:
|
||||
streams += [sys.stdout]
|
||||
if iteration_cb is None:
|
||||
iteration_cb = lambda: False
|
||||
self.instance.found_violated_lazy_constraints = []
|
||||
self.instance.found_violated_user_cuts = []
|
||||
while True:
|
||||
logger.debug("Solving MIP...")
|
||||
with RedirectOutput(streams):
|
||||
results = self._pyomo_solver.solve(
|
||||
tee=True,
|
||||
warmstart=self._is_warm_start_available,
|
||||
)
|
||||
total_wallclock_time += results["Solver"][0]["Wallclock time"]
|
||||
should_repeat = iteration_cb()
|
||||
if not should_repeat:
|
||||
break
|
||||
log = streams[0].getvalue()
|
||||
stats = {
|
||||
"Lower bound": results["Problem"][0]["Lower bound"],
|
||||
"Upper bound": results["Problem"][0]["Upper bound"],
|
||||
"Wallclock time": total_wallclock_time,
|
||||
"Sense": self._obj_sense,
|
||||
"Log": log,
|
||||
}
|
||||
node_count = self._extract_node_count(log)
|
||||
if node_count is not None:
|
||||
stats["Nodes"] = node_count
|
||||
|
||||
ws_value = self._extract_warm_start_value(log)
|
||||
if ws_value is not None:
|
||||
stats["Warm start value"] = ws_value
|
||||
|
||||
return stats
|
||||
|
||||
@staticmethod
|
||||
def __extract(log, regexp, default=None):
|
||||
if regexp is None:
|
||||
@@ -257,6 +278,3 @@ class BasePyomoSolver(InternalSolver):
|
||||
|
||||
def get_sense(self):
|
||||
raise Exception("Not implemented")
|
||||
|
||||
def set_branching_priorities(self, priorities):
|
||||
raise Exception("Not supported")
|
||||
|
||||
@@ -2,14 +2,12 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from io import StringIO
|
||||
|
||||
from pyomo import environ as pe
|
||||
from scipy.stats import randint
|
||||
|
||||
from .base import BasePyomoSolver
|
||||
from .. import RedirectOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,14 +2,12 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from io import StringIO
|
||||
|
||||
from pyomo import environ as pe
|
||||
from scipy.stats import randint
|
||||
|
||||
from .base import BasePyomoSolver
|
||||
from .. import RedirectOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user