mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-11 20:08:52 -06:00
Fix mypy errors
This commit is contained in:
@@ -31,9 +31,9 @@ class H5FieldsExtractor(FeaturesExtractor):
|
||||
data = h5.get_scalar(field)
|
||||
assert data is not None
|
||||
x.append(data)
|
||||
x = np.hstack(x)
|
||||
assert len(x.shape) == 1
|
||||
return x
|
||||
x_np = np.hstack(x)
|
||||
assert len(x_np.shape) == 1
|
||||
return x_np
|
||||
|
||||
def get_var_features(self, h5: H5File) -> np.ndarray:
|
||||
var_types = h5.get_array("static_var_types")
|
||||
@@ -51,13 +51,14 @@ class H5FieldsExtractor(FeaturesExtractor):
|
||||
raise Exception("No constr fields provided")
|
||||
return self._extract(h5, self.constr_fields, n_constr)
|
||||
|
||||
def _extract(self, h5, fields, n_expected):
|
||||
def _extract(self, h5: H5File, fields: List[str], n_expected: int) -> np.ndarray:
|
||||
x = []
|
||||
for field in fields:
|
||||
try:
|
||||
data = h5.get_array(field)
|
||||
except ValueError:
|
||||
v = h5.get_scalar(field)
|
||||
assert v is not None
|
||||
data = np.repeat(v, n_expected)
|
||||
assert data is not None
|
||||
assert len(data.shape) == 1
|
||||
|
||||
Reference in New Issue
Block a user