Fix mypy errors

This commit is contained in:
2023-10-26 13:39:57 -05:00
parent e555dffc0c
commit 2d07a44f7d
12 changed files with 61 additions and 41 deletions

View File

@@ -14,6 +14,7 @@ from miplearn.problems.setcover import (
SetCoverGenerator,
build_setcover_model_pyomo,
)
from miplearn.solvers.abstract import AbstractModel
def test_set_cover_generator() -> None:
@@ -84,6 +85,7 @@ def test_set_cover() -> None:
build_setcover_model_pyomo(data),
build_setcover_model_gurobipy(data),
]:
assert isinstance(model, AbstractModel)
with NamedTemporaryFile() as tempfile:
with H5File(tempfile.name) as h5:
model.optimize()

View File

@@ -12,6 +12,7 @@ from miplearn.problems.stab import (
build_stab_model_pyomo,
build_stab_model_gurobipy,
)
from miplearn.solvers.abstract import AbstractModel
def test_stab() -> None:
@@ -23,6 +24,7 @@ def test_stab() -> None:
build_stab_model_pyomo(data),
build_stab_model_gurobipy(data),
]:
assert isinstance(model, AbstractModel)
with NamedTemporaryFile() as tempfile:
with H5File(tempfile.name) as h5:
model.optimize()

View File

@@ -3,6 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from tempfile import NamedTemporaryFile
from typing import Callable, Any
import numpy as np
import pytest
@@ -40,28 +41,28 @@ def test_pyomo_persistent(data: SetCoverData) -> None:
_test_solver(lambda d: build_setcover_model_pyomo(d, "gurobi_persistent"), data)
def _test_solver(build_model, data):
def _test_solver(build_model: Callable, data: Any) -> None:
_test_extract(build_model(data))
_test_add_constr(build_model(data))
_test_fix_vars(build_model(data))
_test_infeasible(build_model(data))
def _test_extract(model):
def _test_extract(model: AbstractModel) -> None:
with NamedTemporaryFile() as tempfile:
with H5File(tempfile.name) as h5:
def test_scalar(key, expected_value):
def test_scalar(key: str, expected_value: Any) -> None:
actual_value = h5.get_scalar(key)
assert actual_value is not None
assert actual_value == expected_value
def test_array(key, expected_value):
def test_array(key: str, expected_value: Any) -> None:
actual_value = h5.get_array(key)
assert actual_value is not None
assert actual_value.tolist() == expected_value
def test_sparse(key, expected_value):
def test_sparse(key: str, expected_value: Any) -> None:
actual_value = h5.get_sparse(key)
assert actual_value is not None
assert actual_value.todense().tolist() == expected_value
@@ -143,7 +144,7 @@ def _test_extract(model):
assert pool_var_values.shape == (n_sols, 5)
def _test_add_constr(model: AbstractModel):
def _test_add_constr(model: AbstractModel) -> None:
with NamedTemporaryFile() as tempfile:
with H5File(tempfile.name) as h5:
model.add_constrs(
@@ -154,10 +155,12 @@ def _test_add_constr(model: AbstractModel):
)
model.optimize()
model.extract_after_mip(h5)
assert h5.get_array("mip_var_values").tolist() == [1, 0, 0, 0, 1]
mip_var_values = h5.get_array("mip_var_values")
assert mip_var_values is not None
assert mip_var_values.tolist() == [1, 0, 0, 0, 1]
def _test_fix_vars(model: AbstractModel):
def _test_fix_vars(model: AbstractModel) -> None:
with NamedTemporaryFile() as tempfile:
with H5File(tempfile.name) as h5:
model.fix_variables(
@@ -166,10 +169,12 @@ def _test_fix_vars(model: AbstractModel):
)
model.optimize()
model.extract_after_mip(h5)
assert h5.get_array("mip_var_values").tolist() == [1, 0, 0, 0, 1]
mip_var_values = h5.get_array("mip_var_values")
assert mip_var_values is not None
assert mip_var_values.tolist() == [1, 0, 0, 0, 1]
def _test_infeasible(model: AbstractModel):
def _test_infeasible(model: AbstractModel) -> None:
with NamedTemporaryFile() as tempfile:
with H5File(tempfile.name) as h5:
model.fix_variables(