Make internal_solvers into a fixture

master
Alinson S. Xavier 5 years ago
parent 31d0a0861d
commit f3fd1e0cda
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -1,3 +1,18 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from typing import List
import pytest
from miplearn import InternalSolver, GurobiPyomoSolver, GurobiSolver
from miplearn.solvers.pyomo.xpress import XpressPyomoSolver
@pytest.fixture
def internal_solvers() -> List[InternalSolver]:
return [
GurobiPyomoSolver(),
GurobiSolver(),
XpressPyomoSolver(),
]

@ -3,15 +3,12 @@
# Released under the modified BSD license. See COPYING.md for more details.
from inspect import isclass
from typing import List, Callable, Any
from typing import Any
from miplearn.instance.base import Instance
from miplearn.problems.knapsack import KnapsackInstance, GurobiKnapsackInstance
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.internal import InternalSolver
from miplearn.solvers.pyomo.base import BasePyomoSolver
from miplearn.solvers.pyomo.gurobi import GurobiPyomoSolver
from miplearn.solvers.pyomo.xpress import XpressPyomoSolver
def _is_subclass_or_instance(obj: Any, parent_class: Any) -> bool:
@ -34,11 +31,3 @@ def _get_knapsack_instance(solver: Any) -> Instance:
capacity=67.0,
)
assert False
def get_internal_solvers() -> List[InternalSolver]:
return [
GurobiPyomoSolver(),
GurobiSolver(),
XpressPyomoSolver(),
]

@ -4,17 +4,19 @@
import logging
from io import StringIO
from typing import List
from warnings import warn
import pyomo.environ as pe
from miplearn import InternalSolver
from miplearn.solvers import _RedirectOutput
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.pyomo.base import BasePyomoSolver
from . import (
_get_knapsack_instance,
get_internal_solvers,
)
from . import _get_knapsack_instance
# noinspection PyUnresolvedReferences
from .. import internal_solvers
from ..fixtures.infeasible import get_infeasible_instance
logger = logging.getLogger(__name__)
@ -31,8 +33,10 @@ def test_redirect_output() -> None:
assert io.getvalue() == "Hello world\n"
def test_internal_solver_warm_starts() -> None:
for solver in get_internal_solvers():
def test_internal_solver_warm_starts(
internal_solvers: List[InternalSolver],
) -> None:
for solver in internal_solvers:
logger.info("Solver: %s" % solver)
instance = _get_knapsack_instance(solver)
model = instance.to_model()
@ -54,8 +58,10 @@ def test_internal_solver_warm_starts() -> None:
assert stats["Upper bound"] == 725.0
def test_internal_solver() -> None:
for solver in get_internal_solvers():
def test_internal_solver(
internal_solvers: List[InternalSolver],
) -> None:
for solver in internal_solvers:
logger.info("Solver: %s" % solver)
instance = _get_knapsack_instance(solver)
@ -159,8 +165,10 @@ def test_internal_solver() -> None:
assert round(solver.get_dual("eq_capacity")) == 0.0
def test_relax() -> None:
for solver in get_internal_solvers():
def test_relax(
internal_solvers: List[InternalSolver],
) -> None:
for solver in internal_solvers:
instance = _get_knapsack_instance(solver)
solver.set_instance(instance)
solver.relax()
@ -169,8 +177,10 @@ def test_relax() -> None:
assert round(stats["Lower bound"]) == 1288.0
def test_infeasible_instance() -> None:
for solver in get_internal_solvers():
def test_infeasible_instance(
internal_solvers: List[InternalSolver],
) -> None:
for solver in internal_solvers:
instance = get_infeasible_instance(solver)
solver.set_instance(instance)
mip_stats = solver.solve()
@ -185,8 +195,10 @@ def test_infeasible_instance() -> None:
assert lp_stats["LP value"] is None
def test_iteration_cb() -> None:
for solver in get_internal_solvers():
def test_iteration_cb(
internal_solvers: List[InternalSolver],
) -> None:
for solver in internal_solvers:
logger.info("Solver: %s" % solver)
instance = _get_knapsack_instance(solver)
solver.set_instance(instance)

@ -9,18 +9,23 @@ from typing import List, cast
import dill
from miplearn import Instance
from miplearn import Instance, InternalSolver
from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz
from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.learning import LearningSolver
from . import _get_knapsack_instance, get_internal_solvers
from . import _get_knapsack_instance
# noinspection PyUnresolvedReferences
from tests import internal_solvers
logger = logging.getLogger(__name__)
def test_learning_solver() -> None:
def test_learning_solver(
internal_solvers: List[InternalSolver],
) -> None:
for mode in ["exact", "heuristic"]:
for internal_solver in get_internal_solvers():
for internal_solver in internal_solvers:
logger.info("Solver: %s" % internal_solver)
instance = _get_knapsack_instance(internal_solver)
solver = LearningSolver(
@ -61,8 +66,10 @@ def test_learning_solver() -> None:
dill.dump(solver, file)
def test_solve_without_lp() -> None:
for internal_solver in get_internal_solvers():
def test_solve_without_lp(
internal_solvers: List[InternalSolver],
) -> None:
for internal_solver in internal_solvers:
logger.info("Solver: %s" % internal_solver)
instance = _get_knapsack_instance(internal_solver)
solver = LearningSolver(
@ -74,8 +81,10 @@ def test_solve_without_lp() -> None:
solver.solve(instance)
def test_parallel_solve() -> None:
for internal_solver in get_internal_solvers():
def test_parallel_solve(
internal_solvers: List[InternalSolver],
) -> None:
for internal_solver in internal_solvers:
instances = [_get_knapsack_instance(internal_solver) for _ in range(10)]
solver = LearningSolver(solver=internal_solver)
results = solver.parallel_solve(instances, n_jobs=3)
@ -86,8 +95,10 @@ def test_parallel_solve() -> None:
assert len(data.solution.keys()) == 4
def test_solve_fit_from_disk() -> None:
for internal_solver in get_internal_solvers():
def test_solve_fit_from_disk(
internal_solvers: List[InternalSolver],
) -> None:
for internal_solver in internal_solvers:
# Create instances and pickle them
instances: List[Instance] = []
for k in range(3):

Loading…
Cancel
Save