LearningSolver: accept any callable function as solver

This commit is contained in:
2020-04-07 09:35:02 -05:00
parent 5bb109cfad
commit b35f411199
2 changed files with 13 additions and 9 deletions

View File

@@ -2,13 +2,16 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
from . import ObjectiveValueComponent, PrimalSolutionComponent, LazyConstraintsComponent
import pyomo.environ as pe
from pyomo.core import Var
from copy import deepcopy
from scipy.stats import randint
from p_tqdm import p_map
import logging import logging
from copy import deepcopy
import pyomo.environ as pe
from p_tqdm import p_map
from pyomo.core import Var
from scipy.stats import randint
from . import ObjectiveValueComponent, PrimalSolutionComponent, LazyConstraintsComponent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -274,8 +277,9 @@ class LearningSolver:
solver = CPLEXSolver() solver = CPLEXSolver()
elif self.internal_solver_factory == "gurobi": elif self.internal_solver_factory == "gurobi":
solver = GurobiSolver() solver = GurobiSolver()
elif issubclass(self.internal_solver_factory, InternalSolver): elif callable(self.internal_solver_factory):
solver = self.internal_solver_factory() solver = self.internal_solver_factory()
assert isinstance(solver, InternalSolver)
else: else:
raise Exception("solver %s not supported" % self.internal_solver_factory) raise Exception("solver %s not supported" % self.internal_solver_factory)
solver.set_threads(self.threads) solver.set_threads(self.threads)

View File

@@ -4,7 +4,7 @@
from miplearn import LearningSolver, BranchPriorityComponent from miplearn import LearningSolver, BranchPriorityComponent
from miplearn.problems.knapsack import KnapsackInstance from miplearn.problems.knapsack import KnapsackInstance
import pickle, tempfile from miplearn.solvers import GurobiSolver
def _get_instance(): def _get_instance():
@@ -18,7 +18,7 @@ def _get_instance():
def test_solver(): def test_solver():
instance = _get_instance() instance = _get_instance()
for mode in ["exact", "heuristic"]: for mode in ["exact", "heuristic"]:
for internal_solver in ["cplex", "gurobi"]: for internal_solver in ["cplex", "gurobi", GurobiSolver]:
solver = LearningSolver(time_limit=300, solver = LearningSolver(time_limit=300,
gap_tolerance=1e-3, gap_tolerance=1e-3,
threads=1, threads=1,