Redesign component.evaluate

This commit is contained in:
2021-04-02 08:09:35 -05:00
parent 0c687692f7
commit 0bce2051a8
9 changed files with 221 additions and 178 deletions

View File

@@ -116,45 +116,45 @@ class ObjectiveValueComponent(Component):
"Upper bound": np.array(ub),
}
def evaluate(
self,
instances: Union[List[str], List[Instance]],
) -> Dict[str, Dict[str, float]]:
y_pred = self.predict(instances)
y_true = np.array(
[
[
inst.training_data[0]["Lower bound"],
inst.training_data[0]["Upper bound"],
]
for inst in InstanceIterator(instances)
]
)
y_pred_lb = y_pred["Lower bound"]
y_pred_ub = y_pred["Upper bound"]
y_true_lb, y_true_ub = y_true[:, 1], y_true[:, 1]
ev = {
"Lower bound": {
"Mean squared error": mean_squared_error(y_true_lb, y_pred_lb),
"Explained variance": explained_variance_score(y_true_lb, y_pred_lb),
"Max error": max_error(y_true_lb, y_pred_lb),
"Mean absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
"R2": r2_score(y_true_lb, y_pred_lb),
"Median absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
},
"Upper bound": {
"Mean squared error": mean_squared_error(y_true_ub, y_pred_ub),
"Explained variance": explained_variance_score(y_true_ub, y_pred_ub),
"Max error": max_error(y_true_ub, y_pred_ub),
"Mean absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
"R2": r2_score(y_true_ub, y_pred_ub),
"Median absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
},
}
return ev
# def evaluate(
# self,
# instances: Union[List[str], List[Instance]],
# ) -> Dict[str, Dict[str, float]]:
# y_pred = self.predict(instances)
# y_true = np.array(
# [
# [
# inst.training_data[0]["Lower bound"],
# inst.training_data[0]["Upper bound"],
# ]
# for inst in InstanceIterator(instances)
# ]
# )
# y_pred_lb = y_pred["Lower bound"]
# y_pred_ub = y_pred["Upper bound"]
# y_true_lb, y_true_ub = y_true[:, 1], y_true[:, 1]
# ev = {
# "Lower bound": {
# "Mean squared error": mean_squared_error(y_true_lb, y_pred_lb),
# "Explained variance": explained_variance_score(y_true_lb, y_pred_lb),
# "Max error": max_error(y_true_lb, y_pred_lb),
# "Mean absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
# "R2": r2_score(y_true_lb, y_pred_lb),
# "Median absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
# },
# "Upper bound": {
# "Mean squared error": mean_squared_error(y_true_ub, y_pred_ub),
# "Explained variance": explained_variance_score(y_true_ub, y_pred_ub),
# "Max error": max_error(y_true_ub, y_pred_ub),
# "Mean absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
# "R2": r2_score(y_true_ub, y_pred_ub),
# "Median absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
# },
# }
# return ev
@staticmethod
def xy(
def sample_xy(
features: Features,
sample: TrainingSample,
) -> Tuple[Dict, Dict]: