mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Objective: Rewrite sample_evaluate
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple
|
||||
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple, Hashable
|
||||
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LinearRegression
|
||||
@@ -149,3 +149,25 @@ class ObjectiveValueComponent(Component):
|
||||
if "Upper bound" in sample and sample["Upper bound"] is not None:
|
||||
y["Upper bound"] = [[sample["Upper bound"]]]
|
||||
return x, y
|
||||
|
||||
def sample_evaluate(
|
||||
self,
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Dict[Hashable, Dict[str, float]]:
|
||||
def compare(y_pred: float, y_actual: float) -> Dict[str, float]:
|
||||
err = np.round(abs(y_pred - y_actual), 8)
|
||||
return {
|
||||
"Actual value": y_actual,
|
||||
"Predicted value": y_pred,
|
||||
"Absolute error": err,
|
||||
"Relative error": err / y_actual,
|
||||
}
|
||||
|
||||
result: Dict[Hashable, Dict[str, float]] = {}
|
||||
pred = self.sample_predict(features, sample)
|
||||
if "Upper bound" in sample and sample["Upper bound"] is not None:
|
||||
result["Upper bound"] = compare(pred["Upper bound"], sample["Upper bound"])
|
||||
if "Lower bound" in sample and sample["Lower bound"] is not None:
|
||||
result["Lower bound"] = compare(pred["Lower bound"], sample["Lower bound"])
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user