mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Redesign component.evaluate
This commit is contained in:
@@ -116,45 +116,45 @@ class ObjectiveValueComponent(Component):
|
||||
"Upper bound": np.array(ub),
|
||||
}
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
instances: Union[List[str], List[Instance]],
|
||||
) -> Dict[str, Dict[str, float]]:
|
||||
y_pred = self.predict(instances)
|
||||
y_true = np.array(
|
||||
[
|
||||
[
|
||||
inst.training_data[0]["Lower bound"],
|
||||
inst.training_data[0]["Upper bound"],
|
||||
]
|
||||
for inst in InstanceIterator(instances)
|
||||
]
|
||||
)
|
||||
y_pred_lb = y_pred["Lower bound"]
|
||||
y_pred_ub = y_pred["Upper bound"]
|
||||
y_true_lb, y_true_ub = y_true[:, 1], y_true[:, 1]
|
||||
ev = {
|
||||
"Lower bound": {
|
||||
"Mean squared error": mean_squared_error(y_true_lb, y_pred_lb),
|
||||
"Explained variance": explained_variance_score(y_true_lb, y_pred_lb),
|
||||
"Max error": max_error(y_true_lb, y_pred_lb),
|
||||
"Mean absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
|
||||
"R2": r2_score(y_true_lb, y_pred_lb),
|
||||
"Median absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
|
||||
},
|
||||
"Upper bound": {
|
||||
"Mean squared error": mean_squared_error(y_true_ub, y_pred_ub),
|
||||
"Explained variance": explained_variance_score(y_true_ub, y_pred_ub),
|
||||
"Max error": max_error(y_true_ub, y_pred_ub),
|
||||
"Mean absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
|
||||
"R2": r2_score(y_true_ub, y_pred_ub),
|
||||
"Median absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
|
||||
},
|
||||
}
|
||||
return ev
|
||||
# def evaluate(
|
||||
# self,
|
||||
# instances: Union[List[str], List[Instance]],
|
||||
# ) -> Dict[str, Dict[str, float]]:
|
||||
# y_pred = self.predict(instances)
|
||||
# y_true = np.array(
|
||||
# [
|
||||
# [
|
||||
# inst.training_data[0]["Lower bound"],
|
||||
# inst.training_data[0]["Upper bound"],
|
||||
# ]
|
||||
# for inst in InstanceIterator(instances)
|
||||
# ]
|
||||
# )
|
||||
# y_pred_lb = y_pred["Lower bound"]
|
||||
# y_pred_ub = y_pred["Upper bound"]
|
||||
# y_true_lb, y_true_ub = y_true[:, 1], y_true[:, 1]
|
||||
# ev = {
|
||||
# "Lower bound": {
|
||||
# "Mean squared error": mean_squared_error(y_true_lb, y_pred_lb),
|
||||
# "Explained variance": explained_variance_score(y_true_lb, y_pred_lb),
|
||||
# "Max error": max_error(y_true_lb, y_pred_lb),
|
||||
# "Mean absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
|
||||
# "R2": r2_score(y_true_lb, y_pred_lb),
|
||||
# "Median absolute error": mean_absolute_error(y_true_lb, y_pred_lb),
|
||||
# },
|
||||
# "Upper bound": {
|
||||
# "Mean squared error": mean_squared_error(y_true_ub, y_pred_ub),
|
||||
# "Explained variance": explained_variance_score(y_true_ub, y_pred_ub),
|
||||
# "Max error": max_error(y_true_ub, y_pred_ub),
|
||||
# "Mean absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
|
||||
# "R2": r2_score(y_true_ub, y_pred_ub),
|
||||
# "Median absolute error": mean_absolute_error(y_true_ub, y_pred_ub),
|
||||
# },
|
||||
# }
|
||||
# return ev
|
||||
|
||||
@staticmethod
|
||||
def xy(
|
||||
def sample_xy(
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
|
||||
Reference in New Issue
Block a user