Fix failing mypy tests

master
Alinson S. Xavier 5 years ago
parent 5116681291
commit 5aa434b439
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -14,7 +14,7 @@ from tests.solvers import _is_subclass_or_instance
def get_test_pyomo_instances() -> Tuple[List[Instance], List[Any]]: def get_test_pyomo_instances() -> Tuple[List[Instance], List[Any]]:
instances = [ instances: List[Instance] = [
KnapsackInstance( KnapsackInstance(
weights=[23.0, 26.0, 20.0, 18.0], weights=[23.0, 26.0, 20.0, 18.0],
prices=[505.0, 352.0, 458.0, 220.0], prices=[505.0, 352.0, 458.0, 220.0],

@ -5,9 +5,11 @@
import logging import logging
import os import os
import tempfile import tempfile
from typing import List, cast
import dill import dill
from miplearn import Instance
from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz from miplearn.instance.picklegz import PickleGzInstance, write_pickle_gz, read_pickle_gz
from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.learning import LearningSolver from miplearn.solvers.learning import LearningSolver
@ -87,7 +89,7 @@ def test_parallel_solve() -> None:
def test_solve_fit_from_disk() -> None: def test_solve_fit_from_disk() -> None:
for internal_solver in get_internal_solvers(): for internal_solver in get_internal_solvers():
# Create instances and pickle them # Create instances and pickle them
instances = [] instances: List[Instance] = []
for k in range(3): for k in range(3):
instance = _get_knapsack_instance(internal_solver) instance = _get_knapsack_instance(internal_solver)
with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as file: with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as file:
@ -97,7 +99,7 @@ def test_solve_fit_from_disk() -> None:
# Test: solve # Test: solve
solver = LearningSolver(solver=internal_solver) solver = LearningSolver(solver=internal_solver)
solver.solve(instances[0]) solver.solve(instances[0])
instance_loaded = read_pickle_gz(instances[0].filename) instance_loaded = read_pickle_gz(cast(PickleGzInstance, instances[0]).filename)
assert len(instance_loaded.training_data) > 0 assert len(instance_loaded.training_data) > 0
assert instance_loaded.features.instance is not None assert instance_loaded.features.instance is not None
assert instance_loaded.features.variables is not None assert instance_loaded.features.variables is not None
@ -106,7 +108,7 @@ def test_solve_fit_from_disk() -> None:
# Test: parallel_solve # Test: parallel_solve
solver.parallel_solve(instances) solver.parallel_solve(instances)
for instance in instances: for instance in instances:
instance_loaded = read_pickle_gz(instance.filename) instance_loaded = read_pickle_gz(cast(PickleGzInstance, instance).filename)
assert len(instance_loaded.training_data) > 0 assert len(instance_loaded.training_data) > 0
assert instance_loaded.features.instance is not None assert instance_loaded.features.instance is not None
assert instance_loaded.features.variables is not None assert instance_loaded.features.variables is not None
@ -114,7 +116,7 @@ def test_solve_fit_from_disk() -> None:
# Delete temporary files # Delete temporary files
for instance in instances: for instance in instances:
os.remove(instance.filename) os.remove(cast(PickleGzInstance, instance).filename)
def test_simulate_perfect() -> None: def test_simulate_perfect() -> None:

Loading…
Cancel
Save