Add types to InternalSolver

master
Alinson S. Xavier 5 years ago
parent 69a82172b9
commit 1971389a57

@ -23,6 +23,6 @@ jobs:
python -m pip install -i https://pypi.gurobi.com gurobipy python -m pip install -i https://pypi.gurobi.com gurobipy
pip install -r requirements.txt pip install -r requirements.txt
- name: Test with pytest - name: Test
run: | run: |
pytest make test

@ -0,0 +1,2 @@
[mypy]
ignore_missing_imports = True

@ -1,6 +1,7 @@
PYTHON := python3 PYTHON := python3
PYTEST := pytest PYTEST := pytest
PIP := $(PYTHON) -m pip PIP := $(PYTHON) -m pip
MYPY := $(PYTHON) -m mypy
PYTEST_ARGS := -W ignore::DeprecationWarning -vv -x --log-level=DEBUG PYTEST_ARGS := -W ignore::DeprecationWarning -vv -x --log-level=DEBUG
VERSION := 0.2 VERSION := 0.2
@ -38,6 +39,7 @@ reformat:
$(PYTHON) -m black . $(PYTHON) -m black .
test: test:
$(MYPY) -p miplearn
$(PYTEST) $(PYTEST_ARGS) $(PYTEST) $(PYTEST_ARGS)
.PHONY: test test-watch docs install .PHONY: test test-watch docs install

@ -1,14 +1,22 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging
import re import re
import sys import sys
import logging
from io import StringIO from io import StringIO
from random import randint from random import randint
from typing import List, Any, Dict, Union
from . import RedirectOutput from . import RedirectOutput
from .internal import InternalSolver from .internal import (
InternalSolver,
LPSolveStats,
IterationCallback,
LazyCallback,
MIPSolveStats,
)
from .. import Instance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,13 +43,14 @@ class GurobiSolver(InternalSolver):
if params is None: if params is None:
params = {} params = {}
params["InfUnbdInfo"] = True params["InfUnbdInfo"] = True
from gurobipy import GRB import gurobipy
self.GRB = GRB self.gp = gurobipy
self.GRB = gurobipy.GRB
self.instance = None self.instance = None
self.model = None self.model = None
self.params = params self.params = params
self._all_vars = None self._all_vars: Dict = {}
self._bin_vars = None self._bin_vars = None
self.cb_where = None self.cb_where = None
assert lazy_cb_frequency in [1, 2] assert lazy_cb_frequency in [1, 2]
@ -50,10 +59,15 @@ class GurobiSolver(InternalSolver):
else: else:
self.lazy_cb_where = [self.GRB.Callback.MIPSOL, self.GRB.Callback.MIPNODE] self.lazy_cb_where = [self.GRB.Callback.MIPSOL, self.GRB.Callback.MIPNODE]
def set_instance(self, instance, model=None): def set_instance(
self,
instance: Instance,
model: Any = None,
) -> None:
self._raise_if_callback() self._raise_if_callback()
if model is None: if model is None:
model = instance.to_model() model = instance.to_model()
assert isinstance(model, self.gp.Model)
self.instance = instance self.instance = instance
self.model = model self.model = model
self.model.update() self.model.update()
@ -67,7 +81,7 @@ class GurobiSolver(InternalSolver):
self._all_vars = {} self._all_vars = {}
self._bin_vars = {} self._bin_vars = {}
for var in self.model.getVars(): for var in self.model.getVars():
m = re.search(r"([^[]*)\[(.*)\]", var.varName) m = re.search(r"([^[]*)\[(.*)]", var.varName)
if m is None: if m is None:
name = var.varName name = var.varName
idx = [0] idx = [0]
@ -93,9 +107,12 @@ class GurobiSolver(InternalSolver):
if "seed" not in [k.lower() for k in self.params.keys()]: if "seed" not in [k.lower() for k in self.params.keys()]:
self.model.setParam("Seed", randint(0, 1_000_000)) self.model.setParam("Seed", randint(0, 1_000_000))
def solve_lp(self, tee=False): def solve_lp(
self,
tee: bool = False,
) -> LPSolveStats:
self._raise_if_callback() self._raise_if_callback()
streams = [StringIO()] streams: List[Any] = [StringIO()]
if tee: if tee:
streams += [sys.stdout] streams += [sys.stdout]
self._apply_params(streams) self._apply_params(streams)
@ -110,9 +127,17 @@ class GurobiSolver(InternalSolver):
for (idx, var) in vardict.items(): for (idx, var) in vardict.items():
var.vtype = self.GRB.BINARY var.vtype = self.GRB.BINARY
log = streams[0].getvalue() log = streams[0].getvalue()
return {"Optimal value": self.model.objVal, "Log": log} return {
"Optimal value": self.model.objVal,
"Log": log,
}
def solve(self, tee=False, iteration_cb=None, lazy_cb=None): def solve(
self,
tee: bool = False,
iteration_cb: IterationCallback = None,
lazy_cb: LazyCallback = None,
) -> MIPSolveStats:
self._raise_if_callback() self._raise_if_callback()
def cb_wrapper(cb_model, cb_where): def cb_wrapper(cb_model, cb_where):
@ -129,7 +154,7 @@ class GurobiSolver(InternalSolver):
self.params["LazyConstraints"] = 1 self.params["LazyConstraints"] = 1
total_wallclock_time = 0 total_wallclock_time = 0
total_nodes = 0 total_nodes = 0
streams = [StringIO()] streams: List[Any] = [StringIO()]
if tee: if tee:
streams += [sys.stdout] streams += [sys.stdout]
self._apply_params(streams) self._apply_params(streams)
@ -155,32 +180,49 @@ class GurobiSolver(InternalSolver):
sense = "max" sense = "max"
lb = self.model.objVal lb = self.model.objVal
ub = self.model.objBound ub = self.model.objBound
return { stats: MIPSolveStats = {
"Lower bound": lb, "Lower bound": lb,
"Upper bound": ub, "Upper bound": ub,
"Wallclock time": total_wallclock_time, "Wallclock time": total_wallclock_time,
"Nodes": total_nodes, "Nodes": total_nodes,
"Sense": sense, "Sense": sense,
"Log": log, "Log": log,
"Warm start value": self._extract_warm_start_value(log),
} }
ws_value = self._extract_warm_start_value(log)
if ws_value is not None:
stats["Warm start value"] = ws_value
return stats
def get_sense(self): def get_solution(self) -> Dict:
if self.model.modelSense == 1:
return "min"
else:
return "max"
def get_solution(self):
self._raise_if_callback() self._raise_if_callback()
solution: Dict = {}
solution = {}
for (varname, vardict) in self._all_vars.items(): for (varname, vardict) in self._all_vars.items():
solution[varname] = {} solution[varname] = {}
for (idx, var) in vardict.items(): for (idx, var) in vardict.items():
solution[varname][idx] = var.x solution[varname][idx] = var.x
return solution return solution
def set_warm_start(self, solution: Dict) -> None:
self._raise_if_callback()
self._clear_warm_start()
count_fixed, count_total = 0, 0
for (varname, vardict) in solution.items():
for (idx, value) in vardict.items():
count_total += 1
if value is not None:
count_fixed += 1
self._all_vars[varname][idx].start = value
logger.info(
"Setting start values for %d variables (out of %d)"
% (count_fixed, count_total)
)
def get_sense(self):
if self.model.modelSense == 1:
return "min"
else:
return "max"
def get_value(self, var_name, index): def get_value(self, var_name, index):
var = self._all_vars[var_name][index] var = self._all_vars[var_name][index]
return self._get_value(var) return self._get_value(var)
@ -229,25 +271,10 @@ class GurobiSolver(InternalSolver):
else: else:
self.model.addConstr(constraint, name=name) self.model.addConstr(constraint, name=name)
def set_warm_start(self, solution): def _clear_warm_start(self):
self._raise_if_callback() for (varname, vardict) in self._all_vars.items():
count_fixed, count_total = 0, 0
for (varname, vardict) in solution.items():
for (idx, value) in vardict.items():
count_total += 1
if value is not None:
count_fixed += 1
self._all_vars[varname][idx].start = value
logger.info(
"Setting start values for %d variables (out of %d)"
% (count_fixed, count_total)
)
def clear_warm_start(self):
self._raise_if_callback()
for (varname, vardict) in self._all_vars:
for (idx, var) in vardict.items(): for (idx, var) in vardict.items():
var[idx].start = self.GRB.UNDEFINED var.start = self.GRB.UNDEFINED
def fix(self, solution): def fix(self, solution):
self._raise_if_callback() self._raise_if_callback()
@ -311,17 +338,14 @@ class GurobiSolver(InternalSolver):
self.model = self.model.relax() self.model = self.model.relax()
self._update_vars() self._update_vars()
def set_branching_priorities(self, priorities):
self._raise_if_callback()
logger.warning("set_branching_priorities not implemented")
def _extract_warm_start_value(self, log): def _extract_warm_start_value(self, log):
ws = self.__extract(log, "MIP start with objective ([0-9.e+-]*)") ws = self.__extract(log, "MIP start with objective ([0-9.e+-]*)")
if ws is not None: if ws is not None:
ws = float(ws) ws = float(ws)
return ws return ws
def __extract(self, log, regexp, default=None): @staticmethod
def __extract(log, regexp, default=None):
value = default value = default
for line in log.splitlines(): for line in log.splitlines():
matches = re.findall(regexp, line) matches = re.findall(regexp, line)

@ -4,6 +4,9 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TypedDict, Callable, Any, Dict, List
from ..instance import Instance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -12,13 +15,47 @@ class ExtractedConstraint(ABC):
pass pass
class Constraint:
pass
LPSolveStats = TypedDict(
"LPSolveStats",
{
"Optimal value": float,
"Log": str,
},
)
MIPSolveStats = TypedDict(
"MIPSolveStats",
{
"Lower bound": float,
"Upper bound": float,
"Wallclock time": float,
"Nodes": float,
"Sense": str,
"Log": str,
"Warm start value": float,
},
total=False,
)
IterationCallback = Callable[[], bool]
LazyCallback = Callable[[Any, Any], None]
class InternalSolver(ABC): class InternalSolver(ABC):
""" """
Abstract class representing the MIP solver used internally by LearningSolver. Abstract class representing the MIP solver used internally by LearningSolver.
""" """
@abstractmethod @abstractmethod
def solve_lp(self, tee=False): def solve_lp(
self,
tee: bool = False,
) -> LPSolveStats:
""" """
Solves the LP relaxation of the currently loaded instance. After this Solves the LP relaxation of the currently loaded instance. After this
method finishes, the solution can be retrieved by calling `get_solution`. method finishes, the solution can be retrieved by calling `get_solution`.
@ -31,13 +68,17 @@ class InternalSolver(ABC):
Returns Returns
------- -------
dict dict
A dictionary of solver statistics containing the following keys: A dictionary of solver statistics.
"Optimal value".
""" """
pass pass
@abstractmethod @abstractmethod
def solve(self, tee=False, iteration_cb=None, lazy_cb=None): def solve(
self,
tee: bool = False,
iteration_cb: IterationCallback = None,
lazy_cb: LazyCallback = None,
) -> MIPSolveStats:
""" """
Solves the currently loaded instance. After this method finishes, Solves the currently loaded instance. After this method finishes,
the best solution found can be retrieved by calling `get_solution`. the best solution found can be retrieved by calling `get_solution`.
@ -71,7 +112,7 @@ class InternalSolver(ABC):
pass pass
@abstractmethod @abstractmethod
def get_solution(self): def get_solution(self) -> Dict:
""" """
Returns current solution found by the solver. Returns current solution found by the solver.
@ -85,7 +126,7 @@ class InternalSolver(ABC):
pass pass
@abstractmethod @abstractmethod
def set_warm_start(self, solution): def set_warm_start(self, solution: Dict) -> None:
""" """
Sets the warm start to be used by the solver. Sets the warm start to be used by the solver.
@ -97,7 +138,11 @@ class InternalSolver(ABC):
pass pass
@abstractmethod @abstractmethod
def set_instance(self, instance, model=None): def set_instance(
self,
instance: Instance,
model: Any = None,
) -> None:
""" """
Loads the given instance into the solver. Loads the given instance into the solver.
@ -113,7 +158,7 @@ class InternalSolver(ABC):
pass pass
@abstractmethod @abstractmethod
def fix(self, solution): def fix(self, solution: Dict) -> None:
""" """
Fixes the values of a subset of decision variables. Fixes the values of a subset of decision variables.
@ -123,8 +168,7 @@ class InternalSolver(ABC):
""" """
pass pass
@abstractmethod def set_branching_priorities(self, priorities: Dict) -> None:
def set_branching_priorities(self, priorities):
""" """
Sets the branching priorities for the given decision variables. Sets the branching priorities for the given decision variables.
@ -136,36 +180,55 @@ class InternalSolver(ABC):
`get_solution`. Missing values indicate variables whose priorities `get_solution`. Missing values indicate variables whose priorities
should not be modified. should not be modified.
""" """
raise NotImplementedError()
@abstractmethod
def get_constraint_ids(self) -> List[str]:
"""
Returns a list of ids which uniquely identify each constraint in the model.
"""
pass pass
@abstractmethod @abstractmethod
def add_constraint(self, constraint): def add_constraint(self, cobj: Constraint):
""" """
Adds a single constraint to the model. Adds a single constraint to the model.
""" """
pass pass
@abstractmethod @abstractmethod
def get_value(self, var_name, index): def extract_constraint(self, cid: str) -> Constraint:
""" """
Returns the current value of a decision variable. Removes a given constraint from the model and returns an object `cobj` which
can be used to verify if the removed constraint is still satisfied by
the current solution, using `is_constraint_satisfied(cobj)`, and can potentially
be re-added to the model using `add_constraint(cobj)`.
""" """
pass pass
@abstractmethod @abstractmethod
def get_constraint_ids(self): def is_constraint_satisfied(self, cobj: Constraint):
""" """
Returns a list of ids, which uniquely identify each constraint in the model. Returns True if the current solution satisfies the given constraint.
""" """
pass pass
@abstractmethod @abstractmethod
def extract_constraint(self, cid): def set_constraint_sense(self, cid: str, sense: str) -> None:
pass
@abstractmethod
def get_constraint_sense(self, cid: str) -> str:
pass
@abstractmethod
def set_constraint_rhs(self, cid: str, rhs: str) -> None:
pass
@abstractmethod
def get_value(self, var_name, index):
""" """
Removes a given constraint from the model and returns an object `cobj` which Returns the current value of a decision variable.
can be used to verify if the removed constraint is still satisfied by
the current solution, using `is_constraint_satisfied(cobj)`, and can potentially
be re-added to the model using `add_constraint(cobj)`.
""" """
pass pass
@ -210,23 +273,6 @@ class InternalSolver(ABC):
""" """
pass pass
@abstractmethod
def is_constraint_satisfied(self, cobj):
"""Returns True if the current solution satisfies the given constraint."""
pass
@abstractmethod
def set_constraint_sense(self, cid, sense):
pass
@abstractmethod
def get_constraint_sense(self, cid):
pass
@abstractmethod
def set_constraint_rhs(self, cid, rhs):
pass
@abstractmethod @abstractmethod
def get_variables(self): def get_variables(self):
pass pass

@ -6,13 +6,20 @@ import logging
import re import re
import sys import sys
from io import StringIO from io import StringIO
from typing import Any, List, Dict
import pyomo import pyomo
from pyomo import environ as pe from pyomo import environ as pe
from pyomo.core import Var, Constraint from pyomo.core import Var, Constraint
from .. import RedirectOutput from .. import RedirectOutput
from ..internal import InternalSolver from ..internal import (
InternalSolver,
LPSolveStats,
IterationCallback,
LazyCallback,
MIPSolveStats,
)
from ...instance import Instance from ...instance import Instance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,23 +47,74 @@ class BasePyomoSolver(InternalSolver):
for (key, value) in params.items(): for (key, value) in params.items():
self._pyomo_solver.options[key] = value self._pyomo_solver.options[key] = value
def solve_lp(self, tee=False): def solve_lp(
self,
tee: bool = False,
) -> LPSolveStats:
for var in self._bin_vars: for var in self._bin_vars:
lb, ub = var.bounds lb, ub = var.bounds
var.setlb(lb) var.setlb(lb)
var.setub(ub) var.setub(ub)
var.domain = pyomo.core.base.set_types.Reals var.domain = pyomo.core.base.set_types.Reals
self._pyomo_solver.update_var(var) self._pyomo_solver.update_var(var)
results = self._pyomo_solver.solve(tee=tee) streams: List[Any] = [StringIO()]
if tee:
streams += [sys.stdout]
with RedirectOutput(streams):
results = self._pyomo_solver.solve(tee=True)
for var in self._bin_vars: for var in self._bin_vars:
var.domain = pyomo.core.base.set_types.Binary var.domain = pyomo.core.base.set_types.Binary
self._pyomo_solver.update_var(var) self._pyomo_solver.update_var(var)
return { return {
"Optimal value": results["Problem"][0]["Lower bound"], "Optimal value": results["Problem"][0]["Lower bound"],
"Log": streams[0].getvalue(),
}
def solve(
self,
tee: bool = False,
iteration_cb: IterationCallback = None,
lazy_cb: LazyCallback = None,
) -> MIPSolveStats:
if lazy_cb is not None:
raise Exception("lazy callback not supported")
total_wallclock_time = 0
streams: List[Any] = [StringIO()]
if tee:
streams += [sys.stdout]
if iteration_cb is None:
iteration_cb = lambda: False
self.instance.found_violated_lazy_constraints = []
self.instance.found_violated_user_cuts = []
while True:
logger.debug("Solving MIP...")
with RedirectOutput(streams):
results = self._pyomo_solver.solve(
tee=True,
warmstart=self._is_warm_start_available,
)
total_wallclock_time += results["Solver"][0]["Wallclock time"]
should_repeat = iteration_cb()
if not should_repeat:
break
log = streams[0].getvalue()
stats: MIPSolveStats = {
"Lower bound": results["Problem"][0]["Lower bound"],
"Upper bound": results["Problem"][0]["Upper bound"],
"Wallclock time": total_wallclock_time,
"Sense": self._obj_sense,
"Log": log,
} }
node_count = self._extract_node_count(log)
ws_value = self._extract_warm_start_value(log)
if node_count is not None:
stats["Nodes"] = node_count
if ws_value is not None:
stats["Warm start value"] = ws_value
return stats
def get_solution(self): def get_solution(self) -> Dict:
solution = {} solution: Dict = {}
for var in self.model.component_objects(Var): for var in self.model.component_objects(Var):
solution[str(var)] = {} solution[str(var)] = {}
for index in var: for index in var:
@ -65,22 +123,8 @@ class BasePyomoSolver(InternalSolver):
solution[str(var)][index] = var[index].value solution[str(var)][index] = var[index].value
return solution return solution
def get_value(self, var_name, index): def set_warm_start(self, solution: Dict) -> None:
var = self._varname_to_var[var_name] self._clear_warm_start()
return var[index].value
def get_variables(self):
variables = {}
for var in self.model.component_objects(Var):
variables[str(var)] = []
for index in var:
if var[index].fixed:
continue
variables[str(var)] += [index]
return variables
def set_warm_start(self, solution):
self.clear_warm_start()
count_total, count_fixed = 0, 0 count_total, count_fixed = 0, 0
for var_name in solution: for var_name in solution:
var = self._varname_to_var[var_name] var = self._varname_to_var[var_name]
@ -96,16 +140,13 @@ class BasePyomoSolver(InternalSolver):
% (count_fixed, count_total) % (count_fixed, count_total)
) )
def clear_warm_start(self): def set_instance(
for var in self._all_vars: self,
if not var.fixed: instance: Instance,
var.value = None model: Any = None,
self._is_warm_start_available = False ) -> None:
def set_instance(self, instance, model=None):
if model is None: if model is None:
model = instance.to_model() model = instance.to_model()
assert isinstance(instance, Instance)
assert isinstance(model, pe.ConcreteModel) assert isinstance(model, pe.ConcreteModel)
self.instance = instance self.instance = instance
self.model = model self.model = model
@ -114,6 +155,26 @@ class BasePyomoSolver(InternalSolver):
self._update_vars() self._update_vars()
self._update_constrs() self._update_constrs()
def get_value(self, var_name, index):
var = self._varname_to_var[var_name]
return var[index].value
def get_variables(self):
variables = {}
for var in self.model.component_objects(Var):
variables[str(var)] = []
for index in var:
if var[index].fixed:
continue
variables[str(var)] += [index]
return variables
def _clear_warm_start(self):
for var in self._all_vars:
if not var.fixed:
var.value = None
self._is_warm_start_available = False
def _update_obj(self): def _update_obj(self):
self._obj_sense = "max" self._obj_sense = "max"
if self._pyomo_solver._objective.sense == pyomo.core.kernel.objective.minimize: if self._pyomo_solver._objective.sense == pyomo.core.kernel.objective.minimize:
@ -158,46 +219,6 @@ class BasePyomoSolver(InternalSolver):
self._pyomo_solver.add_constraint(constraint) self._pyomo_solver.add_constraint(constraint)
self._update_constrs() self._update_constrs()
def solve(self, tee=False, iteration_cb=None, lazy_cb=None):
if lazy_cb is not None:
raise Exception("lazy callback not supported")
total_wallclock_time = 0
streams = [StringIO()]
if tee:
streams += [sys.stdout]
if iteration_cb is None:
iteration_cb = lambda: False
self.instance.found_violated_lazy_constraints = []
self.instance.found_violated_user_cuts = []
while True:
logger.debug("Solving MIP...")
with RedirectOutput(streams):
results = self._pyomo_solver.solve(
tee=True,
warmstart=self._is_warm_start_available,
)
total_wallclock_time += results["Solver"][0]["Wallclock time"]
should_repeat = iteration_cb()
if not should_repeat:
break
log = streams[0].getvalue()
stats = {
"Lower bound": results["Problem"][0]["Lower bound"],
"Upper bound": results["Problem"][0]["Upper bound"],
"Wallclock time": total_wallclock_time,
"Sense": self._obj_sense,
"Log": log,
}
node_count = self._extract_node_count(log)
if node_count is not None:
stats["Nodes"] = node_count
ws_value = self._extract_warm_start_value(log)
if ws_value is not None:
stats["Warm start value"] = ws_value
return stats
@staticmethod @staticmethod
def __extract(log, regexp, default=None): def __extract(log, regexp, default=None):
if regexp is None: if regexp is None:
@ -257,6 +278,3 @@ class BasePyomoSolver(InternalSolver):
def get_sense(self): def get_sense(self):
raise Exception("Not implemented") raise Exception("Not implemented")
def set_branching_priorities(self, priorities):
raise Exception("Not supported")

@ -2,14 +2,12 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import sys
import logging import logging
from io import StringIO
from pyomo import environ as pe from pyomo import environ as pe
from scipy.stats import randint from scipy.stats import randint
from .base import BasePyomoSolver from .base import BasePyomoSolver
from .. import RedirectOutput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -2,14 +2,12 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import sys
import logging import logging
from io import StringIO
from pyomo import environ as pe from pyomo import environ as pe
from scipy.stats import randint from scipy.stats import randint
from .base import BasePyomoSolver from .base import BasePyomoSolver
from .. import RedirectOutput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -3,7 +3,9 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from inspect import isclass from inspect import isclass
from miplearn import BasePyomoSolver, GurobiSolver, GurobiPyomoSolver from typing import List, Callable
from miplearn import BasePyomoSolver, GurobiSolver, GurobiPyomoSolver, InternalSolver
from miplearn.problems.knapsack import KnapsackInstance, GurobiKnapsackInstance from miplearn.problems.knapsack import KnapsackInstance, GurobiKnapsackInstance
from miplearn.solvers.pyomo.xpress import XpressPyomoSolver from miplearn.solvers.pyomo.xpress import XpressPyomoSolver
@ -31,5 +33,5 @@ def _get_instance(solver):
assert False assert False
def _get_internal_solvers(): def _get_internal_solvers() -> List[Callable[[], InternalSolver]]:
return [GurobiPyomoSolver, GurobiSolver, XpressPyomoSolver] return [GurobiPyomoSolver, GurobiSolver, XpressPyomoSolver]

@ -4,6 +4,7 @@
import logging import logging
from io import StringIO from io import StringIO
from warnings import warn
import pyomo.environ as pe import pyomo.environ as pe
@ -45,6 +46,8 @@ def test_internal_solver_warm_starts():
stats = solver.solve(tee=True) stats = solver.solve(tee=True)
if "Warm start value" in stats: if "Warm start value" in stats:
assert stats["Warm start value"] == 725.0 assert stats["Warm start value"] == 725.0
else:
warn(f"{solver_class.__name__} should set warm start value")
solver.set_warm_start( solver.set_warm_start(
{ {
@ -57,8 +60,7 @@ def test_internal_solver_warm_starts():
} }
) )
stats = solver.solve(tee=True) stats = solver.solve(tee=True)
if "Warm start value" in stats: assert "Warm start value" not in stats
assert stats["Warm start value"] is None
solver.fix( solver.fix(
{ {
@ -86,6 +88,7 @@ def test_internal_solver():
stats = solver.solve_lp() stats = solver.solve_lp()
assert round(stats["Optimal value"], 3) == 1287.923 assert round(stats["Optimal value"], 3) == 1287.923
assert len(stats["Log"]) > 100
solution = solver.get_solution() solution = solver.get_solution()
assert round(solution["x"][0], 3) == 1.000 assert round(solution["x"][0], 3) == 1.000

Loading…
Cancel
Save