mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
LearningSolver: accept any callable function as solver
This commit is contained in:
@@ -2,13 +2,16 @@
|
|||||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
# Released under the modified BSD license. See COPYING.md for more details.
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
from . import ObjectiveValueComponent, PrimalSolutionComponent, LazyConstraintsComponent
|
|
||||||
import pyomo.environ as pe
|
|
||||||
from pyomo.core import Var
|
|
||||||
from copy import deepcopy
|
|
||||||
from scipy.stats import randint
|
|
||||||
from p_tqdm import p_map
|
|
||||||
import logging
|
import logging
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import pyomo.environ as pe
|
||||||
|
from p_tqdm import p_map
|
||||||
|
from pyomo.core import Var
|
||||||
|
from scipy.stats import randint
|
||||||
|
|
||||||
|
from . import ObjectiveValueComponent, PrimalSolutionComponent, LazyConstraintsComponent
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -274,8 +277,9 @@ class LearningSolver:
|
|||||||
solver = CPLEXSolver()
|
solver = CPLEXSolver()
|
||||||
elif self.internal_solver_factory == "gurobi":
|
elif self.internal_solver_factory == "gurobi":
|
||||||
solver = GurobiSolver()
|
solver = GurobiSolver()
|
||||||
elif issubclass(self.internal_solver_factory, InternalSolver):
|
elif callable(self.internal_solver_factory):
|
||||||
solver = self.internal_solver_factory()
|
solver = self.internal_solver_factory()
|
||||||
|
assert isinstance(solver, InternalSolver)
|
||||||
else:
|
else:
|
||||||
raise Exception("solver %s not supported" % self.internal_solver_factory)
|
raise Exception("solver %s not supported" % self.internal_solver_factory)
|
||||||
solver.set_threads(self.threads)
|
solver.set_threads(self.threads)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
from miplearn import LearningSolver, BranchPriorityComponent
|
from miplearn import LearningSolver, BranchPriorityComponent
|
||||||
from miplearn.problems.knapsack import KnapsackInstance
|
from miplearn.problems.knapsack import KnapsackInstance
|
||||||
import pickle, tempfile
|
from miplearn.solvers import GurobiSolver
|
||||||
|
|
||||||
|
|
||||||
def _get_instance():
|
def _get_instance():
|
||||||
@@ -18,7 +18,7 @@ def _get_instance():
|
|||||||
def test_solver():
|
def test_solver():
|
||||||
instance = _get_instance()
|
instance = _get_instance()
|
||||||
for mode in ["exact", "heuristic"]:
|
for mode in ["exact", "heuristic"]:
|
||||||
for internal_solver in ["cplex", "gurobi"]:
|
for internal_solver in ["cplex", "gurobi", GurobiSolver]:
|
||||||
solver = LearningSolver(time_limit=300,
|
solver = LearningSolver(time_limit=300,
|
||||||
gap_tolerance=1e-3,
|
gap_tolerance=1e-3,
|
||||||
threads=1,
|
threads=1,
|
||||||
|
|||||||
Reference in New Issue
Block a user