From b35f411199fec9ef84024d1f932abd98c5cbffc1 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Tue, 7 Apr 2020 09:35:02 -0500 Subject: [PATCH] LearningSolver: accept any callable function as solver --- src/python/miplearn/solvers.py | 14 +++++++++----- src/python/miplearn/tests/test_solver.py | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/python/miplearn/solvers.py b/src/python/miplearn/solvers.py index 6f0e406..50b7082 100644 --- a/src/python/miplearn/solvers.py +++ b/src/python/miplearn/solvers.py @@ -2,13 +2,16 @@ # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. -from . import ObjectiveValueComponent, PrimalSolutionComponent, LazyConstraintsComponent +import logging +from copy import deepcopy + import pyomo.environ as pe +from p_tqdm import p_map from pyomo.core import Var -from copy import deepcopy from scipy.stats import randint -from p_tqdm import p_map -import logging + +from . import ObjectiveValueComponent, PrimalSolutionComponent, LazyConstraintsComponent + logger = logging.getLogger(__name__) @@ -274,8 +277,9 @@ class LearningSolver: solver = CPLEXSolver() elif self.internal_solver_factory == "gurobi": solver = GurobiSolver() - elif issubclass(self.internal_solver_factory, InternalSolver): + elif callable(self.internal_solver_factory): solver = self.internal_solver_factory() + assert isinstance(solver, InternalSolver) else: raise Exception("solver %s not supported" % self.internal_solver_factory) solver.set_threads(self.threads) diff --git a/src/python/miplearn/tests/test_solver.py b/src/python/miplearn/tests/test_solver.py index b61e543..5f9846a 100644 --- a/src/python/miplearn/tests/test_solver.py +++ b/src/python/miplearn/tests/test_solver.py @@ -4,7 +4,7 @@ from miplearn import LearningSolver, BranchPriorityComponent from miplearn.problems.knapsack import KnapsackInstance -import pickle, tempfile +from miplearn.solvers import GurobiSolver def _get_instance(): @@ -18,7 +18,7 @@ def _get_instance(): def test_solver(): instance = _get_instance() for mode in ["exact", "heuristic"]: - for internal_solver in ["cplex", "gurobi"]: + for internal_solver in ["cplex", "gurobi", GurobiSolver]: solver = LearningSolver(time_limit=300, gap_tolerance=1e-3, threads=1,