mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Fix mypy errors
This commit is contained in:
@@ -22,7 +22,7 @@ class AlvLouWeh2017Extractor(FeaturesExtractor):
|
||||
self.with_m3 = with_m3
|
||||
|
||||
def get_instance_features(self, h5: H5File) -> np.ndarray:
|
||||
raise NotImplemented()
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_var_features(self, h5: H5File) -> np.ndarray:
|
||||
"""
|
||||
@@ -197,7 +197,7 @@ class AlvLouWeh2017Extractor(FeaturesExtractor):
|
||||
return features
|
||||
|
||||
def get_constr_features(self, h5: H5File) -> np.ndarray:
|
||||
raise NotImplemented()
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _fix_infinity(m: Optional[np.ndarray]) -> None:
|
||||
|
||||
@@ -31,9 +31,9 @@ class H5FieldsExtractor(FeaturesExtractor):
|
||||
data = h5.get_scalar(field)
|
||||
assert data is not None
|
||||
x.append(data)
|
||||
x = np.hstack(x)
|
||||
assert len(x.shape) == 1
|
||||
return x
|
||||
x_np = np.hstack(x)
|
||||
assert len(x_np.shape) == 1
|
||||
return x_np
|
||||
|
||||
def get_var_features(self, h5: H5File) -> np.ndarray:
|
||||
var_types = h5.get_array("static_var_types")
|
||||
@@ -51,13 +51,14 @@ class H5FieldsExtractor(FeaturesExtractor):
|
||||
raise Exception("No constr fields provided")
|
||||
return self._extract(h5, self.constr_fields, n_constr)
|
||||
|
||||
def _extract(self, h5, fields, n_expected):
|
||||
def _extract(self, h5: H5File, fields: List[str], n_expected: int) -> np.ndarray:
|
||||
x = []
|
||||
for field in fields:
|
||||
try:
|
||||
data = h5.get_array(field)
|
||||
except ValueError:
|
||||
v = h5.get_scalar(field)
|
||||
assert v is not None
|
||||
data = np.repeat(v, n_expected)
|
||||
assert data is not None
|
||||
assert len(data.shape) == 1
|
||||
|
||||
Reference in New Issue
Block a user