Make internal_solvers into a fixture

This commit is contained in:
2021-04-09 18:35:01 -05:00
parent 31d0a0861d
commit f3fd1e0cda
4 changed files with 63 additions and 36 deletions

View File

@@ -9,18 +9,23 @@ from typing import List, cast
import dill
from miplearn import Instance
from miplearn import Instance, InternalSolver
from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.learning import LearningSolver
from . import _get_knapsack_instance, get_internal_solvers
from . import _get_knapsack_instance
# noinspection PyUnresolvedReferences
from tests import internal_solvers
logger = logging.getLogger(__name__)
def test_learning_solver() -> None:
def test_learning_solver(
internal_solvers: List[InternalSolver],
) -> None:
for mode in ["exact", "heuristic"]:
for internal_solver in get_internal_solvers():
for internal_solver in internal_solvers:
logger.info("Solver: %s" % internal_solver)
instance = _get_knapsack_instance(internal_solver)
solver = LearningSolver(
@@ -61,8 +66,10 @@ def test_learning_solver() -> None:
dill.dump(solver, file)
def test_solve_without_lp() -> None:
for internal_solver in get_internal_solvers():
def test_solve_without_lp(
internal_solvers: List[InternalSolver],
) -> None:
for internal_solver in internal_solvers:
logger.info("Solver: %s" % internal_solver)
instance = _get_knapsack_instance(internal_solver)
solver = LearningSolver(
@@ -74,8 +81,10 @@ def test_solve_without_lp() -> None:
solver.solve(instance)
def test_parallel_solve() -> None:
for internal_solver in get_internal_solvers():
def test_parallel_solve(
internal_solvers: List[InternalSolver],
) -> None:
for internal_solver in internal_solvers:
instances = [_get_knapsack_instance(internal_solver) for _ in range(10)]
solver = LearningSolver(solver=internal_solver)
results = solver.parallel_solve(instances, n_jobs=3)
@@ -86,8 +95,10 @@ def test_parallel_solve() -> None:
assert len(data.solution.keys()) == 4
def test_solve_fit_from_disk() -> None:
for internal_solver in get_internal_solvers():
def test_solve_fit_from_disk(
internal_solvers: List[InternalSolver],
) -> None:
for internal_solver in internal_solvers:
# Create instances and pickle them
instances: List[Instance] = []
for k in range(3):