diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index d236039..82a204a 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -23,16 +23,22 @@ class BasePyomoSolver(InternalSolver): Base class for all Pyomo solvers. """ - def __init__(self): + def __init__( + self, + solver_factory, + params, + ): self.instance = None self.model = None self._all_vars = None self._bin_vars = None self._is_warm_start_available = False - self._pyomo_solver = None + self._pyomo_solver = solver_factory self._obj_sense = None self._varname_to_var = {} self._cname_to_constr = {} + for (key, value) in params.items(): + self._pyomo_solver.options[key] = value def solve_lp(self, tee=False): for var in self._bin_vars: @@ -244,3 +250,6 @@ class BasePyomoSolver(InternalSolver): def get_sense(self): raise Exception("Not implemented") + + def set_branching_priorities(self, priorities): + raise Exception("Not supported") diff --git a/miplearn/solvers/pyomo/cplex.py b/miplearn/solvers/pyomo/cplex.py index a1afc48..fd5cfa3 100644 --- a/miplearn/solvers/pyomo/cplex.py +++ b/miplearn/solvers/pyomo/cplex.py @@ -9,29 +9,27 @@ from .base import BasePyomoSolver class CplexPyomoSolver(BasePyomoSolver): - def __init__(self, options=None): - """ - Creates a new CPLEX solver, accessed through Pyomo. - - Parameters - ---------- - options: dict - Dictionary of options to pass to the Pyomo solver. For example, - {"mip_display": 5} to increase the log verbosity. - """ - super().__init__() - self._pyomo_solver = pe.SolverFactory("cplex_persistent") - self._pyomo_solver.options["randomseed"] = randint(low=0, high=1000).rvs() - self._pyomo_solver.options["mip_display"] = 4 - if options is not None: - for (key, value) in options.items(): - self._pyomo_solver.options[key] = value + """ + An InternalSolver that uses CPLEX and the Pyomo modeling language. + + Parameters + ---------- + params: dict + Dictionary of options to pass to the Pyomo solver. For example, + {"mip_display": 5} to increase the log verbosity. + """ + + def __init__(self, params=None): + super().__init__( + solver_factory=pe.SolverFactory("cplex_persistent"), + params={ + "randomseed": randint(low=0, high=1000).rvs(), + "mip_display": 4, + }, + ) def _get_warm_start_regexp(self): return "MIP start .* with objective ([0-9.e+-]*)\\." def _get_node_count_regexp(self): return "^[ *] *([0-9]+)" - - def set_branching_priorities(self, priorities): - raise NotImplementedError diff --git a/miplearn/solvers/pyomo/gurobi.py b/miplearn/solvers/pyomo/gurobi.py index e059a14..72b045d 100644 --- a/miplearn/solvers/pyomo/gurobi.py +++ b/miplearn/solvers/pyomo/gurobi.py @@ -15,22 +15,23 @@ logger = logging.getLogger(__name__) class GurobiPyomoSolver(BasePyomoSolver): - def __init__(self, options=None): - """ - Creates a new Gurobi solver, accessed through Pyomo. - - Parameters - ---------- - options: dict - Dictionary of options to pass to the Pyomo solver. For example, - {"Threads": 4} to set the number of threads. - """ - super().__init__() - self._pyomo_solver = pe.SolverFactory("gurobi_persistent") - self._pyomo_solver.options["Seed"] = randint(low=0, high=1000).rvs() - if options is not None: - for (key, value) in options.items(): - self._pyomo_solver.options[key] = value + """ + An InternalSolver that uses Gurobi and the Pyomo modeling language. + + Parameters + ---------- + params: dict + Dictionary of options to pass to the Pyomo solver. For example, + {"Threads": 4} to set the number of threads. + """ + + def __init__(self, params=None): + super().__init__( + solver_factory=pe.SolverFactory("gurobi_persistent"), + params={ + "Seed": randint(low=0, high=1000).rvs(), + }, + ) def _extract_node_count(self, log): return max(1, int(self._pyomo_solver._solver_model.getAttr("NodeCount")))