mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Make InternalSolver clonable
This commit is contained in:
@@ -62,7 +62,7 @@ def train(args):
|
||||
PickleGzInstance(f) for f in glob.glob(f"{basepath}/train/*.gz")
|
||||
]
|
||||
solver = LearningSolver(
|
||||
solver=lambda: GurobiPyomoSolver(
|
||||
solver=GurobiPyomoSolver(
|
||||
params={
|
||||
"TimeLimit": int(args["--train-time-limit"]),
|
||||
"Threads": int(args["--solver-threads"]),
|
||||
@@ -83,7 +83,7 @@ def test_baseline(args):
|
||||
if not os.path.isfile(csv_filename):
|
||||
solvers = {
|
||||
"baseline": LearningSolver(
|
||||
solver=lambda: GurobiPyomoSolver(
|
||||
solver=GurobiPyomoSolver(
|
||||
params={
|
||||
"TimeLimit": int(args["--test-time-limit"]),
|
||||
"Threads": int(args["--solver-threads"]),
|
||||
@@ -107,7 +107,7 @@ def test_ml(args):
|
||||
if not os.path.isfile(csv_filename):
|
||||
solvers = {
|
||||
"ml-exact": LearningSolver(
|
||||
solver=lambda: GurobiPyomoSolver(
|
||||
solver=GurobiPyomoSolver(
|
||||
params={
|
||||
"TimeLimit": int(args["--test-time-limit"]),
|
||||
"Threads": int(args["--solver-threads"]),
|
||||
@@ -115,7 +115,7 @@ def test_ml(args):
|
||||
),
|
||||
),
|
||||
"ml-heuristic": LearningSolver(
|
||||
solver=lambda: GurobiPyomoSolver(
|
||||
solver=GurobiPyomoSolver(
|
||||
params={
|
||||
"TimeLimit": int(args["--test-time-limit"]),
|
||||
"Threads": int(args["--solver-threads"]),
|
||||
|
||||
@@ -51,9 +51,11 @@ class GurobiSolver(InternalSolver):
|
||||
) -> None:
|
||||
import gurobipy
|
||||
|
||||
assert lazy_cb_frequency in [1, 2]
|
||||
if params is None:
|
||||
params = {}
|
||||
params["InfUnbdInfo"] = True
|
||||
params["Seed"] = randint(0, 1_000_000)
|
||||
|
||||
self.gp = gurobipy
|
||||
self.instance: Optional[Instance] = None
|
||||
@@ -62,9 +64,9 @@ class GurobiSolver(InternalSolver):
|
||||
self.varname_to_var: Dict[str, "gurobipy.Var"] = {}
|
||||
self.bin_vars: List["gurobipy.Var"] = []
|
||||
self.cb_where: Optional[int] = None
|
||||
self.lazy_cb_frequency = lazy_cb_frequency
|
||||
|
||||
assert lazy_cb_frequency in [1, 2]
|
||||
if lazy_cb_frequency == 1:
|
||||
if self.lazy_cb_frequency == 1:
|
||||
self.lazy_cb_where = [self.gp.GRB.Callback.MIPSOL]
|
||||
else:
|
||||
self.lazy_cb_where = [
|
||||
@@ -113,8 +115,6 @@ class GurobiSolver(InternalSolver):
|
||||
with _RedirectOutput(streams):
|
||||
for (name, value) in self.params.items():
|
||||
self.model.setParam(name, value)
|
||||
if "seed" not in [k.lower() for k in self.params.keys()]:
|
||||
self.model.setParam("Seed", randint(0, 1_000_000))
|
||||
|
||||
@overrides
|
||||
def solve_lp(
|
||||
@@ -428,3 +428,10 @@ class GurobiSolver(InternalSolver):
|
||||
self.instance = None
|
||||
self.model = None
|
||||
self.cb_where = None
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "GurobiSolver":
|
||||
return GurobiSolver(
|
||||
params=self.params,
|
||||
lazy_cb_frequency=self.lazy_cb_frequency,
|
||||
)
|
||||
|
||||
@@ -284,3 +284,11 @@ class InternalSolver(ABC, EnforceOverrides):
|
||||
model before a solution is available.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clone(self) -> "InternalSolver":
|
||||
"""
|
||||
Returns a new copy of this solver with identical parameters, but otherwise
|
||||
completely unitialized.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -91,20 +91,20 @@ class LearningSolver:
|
||||
self,
|
||||
components: List[Component] = None,
|
||||
mode: str = "exact",
|
||||
solver: Callable[[], InternalSolver] = None,
|
||||
solver: InternalSolver = None,
|
||||
use_lazy_cb: bool = False,
|
||||
solve_lp: bool = True,
|
||||
simulate_perfect: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
if solver is None:
|
||||
solver = GurobiPyomoSolver
|
||||
assert callable(solver), f"Callable expected. Found {solver.__class__} instead."
|
||||
solver = GurobiPyomoSolver()
|
||||
assert isinstance(solver, InternalSolver)
|
||||
self.components: Dict[str, Component] = {}
|
||||
self.internal_solver: Optional[InternalSolver] = None
|
||||
self.internal_solver_prototype: InternalSolver = solver
|
||||
self.mode: str = mode
|
||||
self.simulate_perfect: bool = simulate_perfect
|
||||
self.solve_lp: bool = solve_lp
|
||||
self.solver_factory: Callable[[], InternalSolver] = solver
|
||||
self.tee = False
|
||||
self.use_lazy_cb: bool = use_lazy_cb
|
||||
if components is not None:
|
||||
@@ -144,7 +144,7 @@ class LearningSolver:
|
||||
# Initialize internal solver
|
||||
# -------------------------------------------------------
|
||||
self.tee = tee
|
||||
self.internal_solver = self.solver_factory()
|
||||
self.internal_solver = self.internal_solver_prototype.clone()
|
||||
assert self.internal_solver is not None
|
||||
assert isinstance(self.internal_solver, InternalSolver)
|
||||
self.internal_solver.set_instance(instance, model)
|
||||
|
||||
@@ -46,6 +46,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
) -> None:
|
||||
self.instance: Optional[Instance] = None
|
||||
self.model: Optional[pe.ConcreteModel] = None
|
||||
self.params = params
|
||||
self._all_vars: List[pe.Var] = []
|
||||
self._bin_vars: List[pe.Var] = []
|
||||
self._is_warm_start_available: bool = False
|
||||
|
||||
@@ -28,7 +28,6 @@ class CplexPyomoSolver(BasePyomoSolver):
|
||||
) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "randomseed" not in params.keys():
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
if "mip_display" not in params.keys():
|
||||
params["mip_display"] = 4
|
||||
@@ -44,3 +43,7 @@ class CplexPyomoSolver(BasePyomoSolver):
|
||||
@overrides
|
||||
def _get_node_count_regexp(self):
|
||||
return "^[ *] *([0-9]+)"
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "CplexPyomoSolver":
|
||||
return CplexPyomoSolver(params=self.params)
|
||||
|
||||
@@ -32,7 +32,6 @@ class GurobiPyomoSolver(BasePyomoSolver):
|
||||
) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "seed" not in params.keys():
|
||||
params["seed"] = randint(low=0, high=1000).rvs()
|
||||
super().__init__(
|
||||
solver_factory=pe.SolverFactory("gurobi_persistent"),
|
||||
@@ -61,3 +60,7 @@ class GurobiPyomoSolver(BasePyomoSolver):
|
||||
var = self._varname_to_var[varname]
|
||||
gvar = self._pyomo_solver._pyomo_var_to_solver_var_map[var]
|
||||
gvar.setAttr(GRB.Attr.BranchPriority, int(round(priority)))
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "GurobiPyomoSolver":
|
||||
return GurobiPyomoSolver(params=self.params)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from overrides import overrides
|
||||
from pyomo import environ as pe
|
||||
from scipy.stats import randint
|
||||
|
||||
@@ -27,9 +28,12 @@ class XpressPyomoSolver(BasePyomoSolver):
|
||||
def __init__(self, params: SolverParams = None) -> None:
|
||||
if params is None:
|
||||
params = {}
|
||||
if "randomseed" not in params.keys():
|
||||
params["randomseed"] = randint(low=0, high=1000).rvs()
|
||||
super().__init__(
|
||||
solver_factory=pe.SolverFactory("xpress_persistent"),
|
||||
params=params,
|
||||
)
|
||||
|
||||
@overrides
|
||||
def clone(self) -> "XpressPyomoSolver":
|
||||
return XpressPyomoSolver(params=self.params)
|
||||
|
||||
@@ -64,10 +64,8 @@ def stab_instance() -> Instance:
|
||||
@pytest.fixture
|
||||
def solver() -> LearningSolver:
|
||||
return LearningSolver(
|
||||
solver=lambda: GurobiSolver(),
|
||||
components=[
|
||||
UserCutsComponent(),
|
||||
],
|
||||
solver=GurobiSolver(),
|
||||
components=[UserCutsComponent()],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,5 +35,9 @@ def _get_knapsack_instance(solver):
|
||||
assert False
|
||||
|
||||
|
||||
def get_internal_solvers() -> List[Callable[[], InternalSolver]]:
|
||||
return [GurobiPyomoSolver, GurobiSolver, XpressPyomoSolver]
|
||||
def get_internal_solvers() -> List[InternalSolver]:
|
||||
return [
|
||||
GurobiPyomoSolver(),
|
||||
GurobiSolver(),
|
||||
XpressPyomoSolver(),
|
||||
]
|
||||
|
||||
@@ -32,18 +32,17 @@ def test_redirect_output():
|
||||
|
||||
|
||||
def test_internal_solver_warm_starts():
|
||||
for solver_class in get_internal_solvers():
|
||||
logger.info("Solver: %s" % solver_class)
|
||||
instance = _get_knapsack_instance(solver_class)
|
||||
for solver in get_internal_solvers():
|
||||
logger.info("Solver: %s" % solver)
|
||||
instance = _get_knapsack_instance(solver)
|
||||
model = instance.to_model()
|
||||
solver = solver_class()
|
||||
solver.set_instance(instance, model)
|
||||
solver.set_warm_start({"x[0]": 1.0, "x[1]": 0.0, "x[2]": 0.0, "x[3]": 1.0})
|
||||
stats = solver.solve(tee=True)
|
||||
if stats["Warm start value"] is not None:
|
||||
assert stats["Warm start value"] == 725.0
|
||||
else:
|
||||
warn(f"{solver_class.__name__} should set warm start value")
|
||||
warn(f"{solver.__class__.__name__} should set warm start value")
|
||||
|
||||
solver.set_warm_start({"x[0]": 1.0, "x[1]": 1.0, "x[2]": 1.0, "x[3]": 1.0})
|
||||
stats = solver.solve(tee=True)
|
||||
@@ -56,12 +55,11 @@ def test_internal_solver_warm_starts():
|
||||
|
||||
|
||||
def test_internal_solver():
|
||||
for solver_class in get_internal_solvers():
|
||||
logger.info("Solver: %s" % solver_class)
|
||||
for solver in get_internal_solvers():
|
||||
logger.info("Solver: %s" % solver)
|
||||
|
||||
instance = _get_knapsack_instance(solver_class)
|
||||
instance = _get_knapsack_instance(solver)
|
||||
model = instance.to_model()
|
||||
solver = solver_class()
|
||||
solver.set_instance(instance, model)
|
||||
|
||||
assert solver.get_variable_names() == ["x[0]", "x[1]", "x[2]", "x[3]"]
|
||||
@@ -150,9 +148,8 @@ def test_internal_solver():
|
||||
|
||||
|
||||
def test_relax():
|
||||
for solver_class in get_internal_solvers():
|
||||
instance = _get_knapsack_instance(solver_class)
|
||||
solver = solver_class()
|
||||
for solver in get_internal_solvers():
|
||||
instance = _get_knapsack_instance(solver)
|
||||
solver.set_instance(instance)
|
||||
solver.relax()
|
||||
stats = solver.solve()
|
||||
@@ -160,9 +157,8 @@ def test_relax():
|
||||
|
||||
|
||||
def test_infeasible_instance():
|
||||
for solver_class in get_internal_solvers():
|
||||
instance = get_infeasible_instance(solver_class)
|
||||
solver = solver_class()
|
||||
for solver in get_internal_solvers():
|
||||
instance = get_infeasible_instance(solver)
|
||||
solver.set_instance(instance)
|
||||
stats = solver.solve()
|
||||
|
||||
@@ -177,10 +173,9 @@ def test_infeasible_instance():
|
||||
|
||||
|
||||
def test_iteration_cb():
|
||||
for solver_class in get_internal_solvers():
|
||||
logger.info("Solver: %s" % solver_class)
|
||||
instance = _get_knapsack_instance(solver_class)
|
||||
solver = solver_class()
|
||||
for solver in get_internal_solvers():
|
||||
logger.info("Solver: %s" % solver)
|
||||
instance = _get_knapsack_instance(solver)
|
||||
solver.set_instance(instance)
|
||||
count = 0
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ def test_solve_fit_from_disk():
|
||||
|
||||
|
||||
def test_simulate_perfect():
|
||||
internal_solver = GurobiSolver
|
||||
internal_solver = GurobiSolver()
|
||||
instance = _get_knapsack_instance(internal_solver)
|
||||
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
|
||||
write_pickle_gz(instance, tmp.name)
|
||||
|
||||
Reference in New Issue
Block a user