mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make InternalSolver clonable
This commit is contained in:
@@ -51,9 +51,11 @@ class GurobiSolver(InternalSolver):
|
||||
) -> None:
|
||||
import gurobipy
|
||||
|
||||
assert lazy_cb_frequency in [1, 2]
|
||||
if params is None:
|
||||
params = {}
|
||||
params["InfUnbdInfo"] = True
|
||||
params["Seed"] = randint(0, 1_000_000)
|
||||
|
||||
self.gp = gurobipy
|
||||
self.instance: Optional[Instance] = None
|
||||
@@ -62,9 +64,9 @@ class GurobiSolver(InternalSolver):
|
||||
self.varname_to_var: Dict[str, "gurobipy.Var"] = {}
|
||||
self.bin_vars: List["gurobipy.Var"] = []
|
||||
self.cb_where: Optional[int] = None
|
||||
self.lazy_cb_frequency = lazy_cb_frequency
|
||||
|
||||
assert lazy_cb_frequency in [1, 2]
|
||||
if lazy_cb_frequency == 1:
|
||||
if self.lazy_cb_frequency == 1:
|
||||
self.lazy_cb_where = [self.gp.GRB.Callback.MIPSOL]
|
||||
else:
|
||||
self.lazy_cb_where = [
|
||||
@@ -113,8 +115,6 @@ class GurobiSolver(InternalSolver):
|
||||
with _RedirectOutput(streams):
|
||||
for (name, value) in self.params.items():
|
||||
self.model.setParam(name, value)
|
||||
if "seed" not in [k.lower() for k in self.params.keys()]:
|
||||
self.model.setParam("Seed", randint(0, 1_000_000))
|
||||
|
||||
@overrides
|
||||
def solve_lp(
|
||||
@@ -428,3 +428,10 @@ class GurobiSolver(InternalSolver):
|
||||
self.instance = None
|
||||
self.model = None
|
||||
self.cb_where = None
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "GurobiSolver":
|
||||
return GurobiSolver(
|
||||
params=self.params,
|
||||
lazy_cb_frequency=self.lazy_cb_frequency,
|
||||
)
|
||||
|
||||
@@ -284,3 +284,11 @@ class InternalSolver(ABC, EnforceOverrides):
|
||||
model before a solution is available.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clone(self) -> "InternalSolver":
|
||||
"""
|
||||
Returns a new copy of this solver with identical parameters, but otherwise
|
||||
completely unitialized.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -91,20 +91,20 @@ class LearningSolver:
|
||||
self,
|
||||
components: List[Component] = None,
|
||||
mode: str = "exact",
|
||||
solver: Callable[[], InternalSolver] = None,
|
||||
solver: InternalSolver = None,
|
||||
use_lazy_cb: bool = False,
|
||||
solve_lp: bool = True,
|
||||
simulate_perfect: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if solver is None:
|
||||
solver = GurobiPyomoSolver
|
||||
assert callable(solver), f"Callable expected. Found {solver.__class__} instead."
|
||||
solver = GurobiPyomoSolver()
|
||||
assert isinstance(solver, InternalSolver)
|
||||
self.components: Dict[str, Component] = {}
|
||||
self.internal_solver: Optional[InternalSolver] = None
|
||||
self.internal_solver_prototype: InternalSolver = solver
|
||||
self.mode: str = mode
|
||||
self.simulate_perfect: bool = simulate_perfect
|
||||
self.solve_lp: bool = solve_lp
|
||||
self.solver_factory: Callable[[], InternalSolver] = solver
|
||||
self.tee = False
|
||||
self.use_lazy_cb: bool = use_lazy_cb
|
||||
if components is not None:
|
||||
@@ -144,7 +144,7 @@ class LearningSolver:
|
||||
# Initialize internal solver
|
||||
# -------------------------------------------------------
|
||||
self.tee = tee
|
||||
self.internal_solver = self.solver_factory()
|
||||
self.internal_solver = self.internal_solver_prototype.clone()
|
||||
assert self.internal_solver is not None
|
||||
assert isinstance(self.internal_solver, InternalSolver)
|
||||
self.internal_solver.set_instance(instance, model)
|
||||
|
||||
@@ -46,6 +46,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
) -> None:
|
||||
self.instance: Optional[Instance] = None
|
||||
self.model: Optional[pe.ConcreteModel] = None
|
||||
self.params = params
|
||||
self._all_vars: List[pe.Var] = []
|
||||
self._bin_vars: List[pe.Var] = []
|
||||
self._is_warm_start_available: bool = False
|
||||
|
||||
@@ -28,8 +28,7 @@ class CplexPyomoSolver(BasePyomoSolver):
|
||||
) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "randomseed" not in params.keys():
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
if "mip_display" not in params.keys():
|
||||
params["mip_display"] = 4
|
||||
super().__init__(
|
||||
@@ -44,3 +43,7 @@ class CplexPyomoSolver(BasePyomoSolver):
|
||||
@overrides
|
||||
def _get_node_count_regexp(self):
|
||||
return "^[ *] *([0-9]+)"
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "CplexPyomoSolver":
|
||||
return CplexPyomoSolver(params=self.params)
|
||||
|
||||
@@ -32,8 +32,7 @@ class GurobiPyomoSolver(BasePyomoSolver):
|
||||
) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "seed" not in params.keys():
|
||||
params["seed"] = randint(low=0, high=1000).rvs()
|
||||
params["seed"] = randint(low=0, high=1000).rvs()
|
||||
super().__init__(
|
||||
solver_factory=pe.SolverFactory("gurobi_persistent"),
|
||||
params=params,
|
||||
@@ -61,3 +60,7 @@ class GurobiPyomoSolver(BasePyomoSolver):
|
||||
var = self._varname_to_var[varname]
|
||||
gvar = self._pyomo_solver._pyomo_var_to_solver_var_map[var]
|
||||
gvar.setAttr(GRB.Attr.BranchPriority, int(round(priority)))
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "GurobiPyomoSolver":
|
||||
return GurobiPyomoSolver(params=self.params)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from overrides import overrides
|
||||
from pyomo import environ as pe
|
||||
from scipy.stats import randint
|
||||
|
||||
@@ -27,9 +28,12 @@ class XpressPyomoSolver(BasePyomoSolver):
|
||||
def __init__(self, params: SolverParams = None) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "randomseed" not in params.keys():
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
super().__init__(
|
||||
solver_factory=pe.SolverFactory("xpress_persistent"),
|
||||
params=params,
|
||||
)
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "XpressPyomoSolver":
|
||||
return XpressPyomoSolver(params=self.params)
|
||||
|
||||
Reference in New Issue
Block a user