Fix mypy errors

This commit is contained in:
2023-10-26 13:39:57 -05:00
parent e555dffc0c
commit 2d07a44f7d
12 changed files with 61 additions and 41 deletions

View File

@@ -9,9 +9,10 @@ import numpy as np
from scipy.sparse import lil_matrix
from miplearn.h5 import H5File
from miplearn.solvers.abstract import AbstractModel
class GurobiModel:
class GurobiModel(AbstractModel):
_supports_basis_status = True
_supports_sensitivity_analysis = True
_supports_node_count = True

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import List, Any, Union
from typing import List, Any, Union, Dict, Callable, Optional
from miplearn.h5 import H5File
from miplearn.io import _to_h5_filename
@@ -11,23 +11,28 @@ from miplearn.solvers.abstract import AbstractModel
class LearningSolver:
def __init__(self, components: List[Any], skip_lp=False):
def __init__(self, components: List[Any], skip_lp: bool = False) -> None:
self.components = components
self.skip_lp = skip_lp
def fit(self, data_filenames):
def fit(self, data_filenames: List[str]) -> None:
h5_filenames = [_to_h5_filename(f) for f in data_filenames]
for comp in self.components:
comp.fit(h5_filenames)
def optimize(self, model: Union[str, AbstractModel], build_model=None):
def optimize(
self,
model: Union[str, AbstractModel],
build_model: Optional[Callable] = None,
) -> Dict[str, Any]:
if isinstance(model, str):
h5_filename = _to_h5_filename(model)
assert build_model is not None
model = build_model(model)
assert isinstance(model, AbstractModel)
else:
h5_filename = NamedTemporaryFile().name
stats = {}
stats: Dict[str, Any] = {}
mode = "r+" if exists(h5_filename) else "w"
with H5File(h5_filename, mode) as h5:
model.extract_after_load(h5)

View File

@@ -2,7 +2,7 @@
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
from numbers import Number
from typing import Optional, Dict, List, Any
from typing import Optional, Dict, List, Any, Tuple, Union
import numpy as np
import pyomo
@@ -24,7 +24,7 @@ class PyomoModel(AbstractModel):
self.is_persistent = hasattr(self.solver, "set_instance")
if self.is_persistent:
self.solver.set_instance(model)
self.results = None
self.results: Optional[Dict] = None
self._is_warm_start_available = False
if not hasattr(self.inner, "dual"):
self.inner.dual = Suffix(direction=Suffix.IMPORT)
@@ -56,7 +56,7 @@ class PyomoModel(AbstractModel):
raise Exception(f"Unknown sense: {sense}")
self.solver.add_constraint(eq)
def _var_names_to_vars(self, var_names):
def _var_names_to_vars(self, var_names: np.ndarray) -> List[Any]:
varname_to_var = {}
for var in self.inner.component_objects(Var):
for idx in var:
@@ -70,12 +70,14 @@ class PyomoModel(AbstractModel):
h5.put_scalar("static_sense", self._get_sense())
def extract_after_lp(self, h5: H5File) -> None:
assert self.results is not None
self._extract_after_lp_vars(h5)
self._extract_after_lp_constrs(h5)
h5.put_scalar("lp_obj_value", self.results["Problem"][0]["Lower bound"])
h5.put_scalar("lp_wallclock_time", self._get_runtime())
def _get_runtime(self):
def _get_runtime(self) -> float:
assert self.results is not None
solver_dict = self.results["Solver"][0]
for key in ["Wallclock time", "User time"]:
if isinstance(solver_dict[key], Number):
@@ -83,6 +85,7 @@ class PyomoModel(AbstractModel):
raise Exception("Time unavailable")
def extract_after_mip(self, h5: H5File) -> None:
assert self.results is not None
h5.put_scalar("mip_wallclock_time", self._get_runtime())
if self.results["Solver"][0]["Termination condition"] == "infeasible":
return
@@ -150,7 +153,7 @@ class PyomoModel(AbstractModel):
var.value = val
self._is_warm_start_available = True
def _extract_after_load_vars(self, h5):
def _extract_after_load_vars(self, h5: H5File) -> None:
names: List[str] = []
types: List[str] = []
upper_bounds: List[float] = []
@@ -211,7 +214,7 @@ class PyomoModel(AbstractModel):
h5.put_array("static_var_obj_coeffs", np.array(obj_coeffs))
h5.put_scalar("static_obj_offset", obj_offset)
def _extract_after_load_constrs(self, h5):
def _extract_after_load_constrs(self, h5: H5File) -> None:
names: List[str] = []
rhs: List[float] = []
senses: List[str] = []
@@ -219,7 +222,7 @@ class PyomoModel(AbstractModel):
lhs_col: List[int] = []
lhs_data: List[float] = []
varname_to_idx = {}
varname_to_idx: Dict[str, int] = {}
for var in self.inner.component_objects(Var):
for idx in var:
varname = var.name
@@ -285,7 +288,7 @@ class PyomoModel(AbstractModel):
h5.put_array("static_constr_rhs", np.array(rhs))
h5.put_array("static_constr_sense", np.array(senses, dtype="S"))
def _extract_after_lp_vars(self, h5):
def _extract_after_lp_vars(self, h5: H5File) -> None:
rc = []
values = []
for var in self.inner.component_objects(Var):
@@ -296,7 +299,7 @@ class PyomoModel(AbstractModel):
h5.put_array("lp_var_reduced_costs", np.array(rc))
h5.put_array("lp_var_values", np.array(values))
def _extract_after_lp_constrs(self, h5):
def _extract_after_lp_constrs(self, h5: H5File) -> None:
dual = []
slacks = []
for constr in self.inner.component_objects(pyomo.core.Constraint):
@@ -307,7 +310,7 @@ class PyomoModel(AbstractModel):
h5.put_array("lp_constr_dual_values", np.array(dual))
h5.put_array("lp_constr_slacks", np.array(slacks))
def _extract_after_mip_vars(self, h5):
def _extract_after_mip_vars(self, h5: H5File) -> None:
values = []
for var in self.inner.component_objects(Var):
for idx in var:
@@ -315,7 +318,7 @@ class PyomoModel(AbstractModel):
values.append(v.value)
h5.put_array("mip_var_values", np.array(values))
def _extract_after_mip_constrs(self, h5):
def _extract_after_mip_constrs(self, h5: H5File) -> None:
slacks = []
for constr in self.inner.component_objects(pyomo.core.Constraint):
for idx in constr:
@@ -323,7 +326,7 @@ class PyomoModel(AbstractModel):
slacks.append(abs(self.inner.slack[c]))
h5.put_array("mip_constr_slacks", np.array(slacks))
def _parse_pyomo_expr(self, expr: Any):
def _parse_pyomo_expr(self, expr: Any) -> Tuple[Dict[str, float], float]:
lhs = {}
offset = 0.0
if isinstance(expr, SumExpression):
@@ -332,7 +335,7 @@ class PyomoModel(AbstractModel):
lhs[term._args_[1].name] = float(term._args_[0])
elif isinstance(term, _GeneralVarData):
lhs[term.name] = 1.0
elif isinstance(term, Number):
elif isinstance(term, float):
offset += term
else:
raise Exception(f"Unknown term type: {term.__class__.__name__}")
@@ -342,7 +345,7 @@ class PyomoModel(AbstractModel):
raise Exception(f"Unknown expression type: {expr.__class__.__name__}")
return lhs, offset
def _gap(self, zp, zd, tol=1e-6):
def _gap(self, zp: float, zd: float, tol: float = 1e-6) -> float:
# Reference: https://www.gurobi.com/documentation/9.5/refman/mipgap2.html
if abs(zp) < tol:
if abs(zd) < tol:
@@ -352,7 +355,7 @@ class PyomoModel(AbstractModel):
else:
return abs(zp - zd) / abs(zp)
def _get_sense(self):
def _get_sense(self) -> str:
for obj in self.inner.component_objects(Objective):
sense = obj.sense
if sense == pyomo.core.kernel.objective.minimize:
@@ -361,6 +364,7 @@ class PyomoModel(AbstractModel):
return "max"
else:
raise Exception(f"Unknown sense: ${sense}")
raise Exception(f"No objective")
def write(self, filename: str) -> None:
self.inner.write(filename, io_options={"symbolic_solver_labels": True})