Fix mypy errors

This commit is contained in:
2023-10-26 13:39:57 -05:00
parent e555dffc0c
commit 2d07a44f7d
12 changed files with 61 additions and 41 deletions

View File

@@ -22,7 +22,7 @@ class AlvLouWeh2017Extractor(FeaturesExtractor):
self.with_m3 = with_m3
def get_instance_features(self, h5: H5File) -> np.ndarray:
raise NotImplemented()
raise NotImplementedError()
def get_var_features(self, h5: H5File) -> np.ndarray:
"""
@@ -197,7 +197,7 @@ class AlvLouWeh2017Extractor(FeaturesExtractor):
return features
def get_constr_features(self, h5: H5File) -> np.ndarray:
raise NotImplemented()
raise NotImplementedError()
def _fix_infinity(m: Optional[np.ndarray]) -> None:

View File

@@ -31,9 +31,9 @@ class H5FieldsExtractor(FeaturesExtractor):
data = h5.get_scalar(field)
assert data is not None
x.append(data)
x = np.hstack(x)
assert len(x.shape) == 1
return x
x_np = np.hstack(x)
assert len(x_np.shape) == 1
return x_np
def get_var_features(self, h5: H5File) -> np.ndarray:
var_types = h5.get_array("static_var_types")
@@ -51,13 +51,14 @@ class H5FieldsExtractor(FeaturesExtractor):
raise Exception("No constr fields provided")
return self._extract(h5, self.constr_fields, n_constr)
def _extract(self, h5, fields, n_expected):
def _extract(self, h5: H5File, fields: List[str], n_expected: int) -> np.ndarray:
x = []
for field in fields:
try:
data = h5.get_array(field)
except ValueError:
v = h5.get_scalar(field)
assert v is not None
data = np.repeat(v, n_expected)
assert data is not None
assert len(data.shape) == 1