Objective: Rewrite sample_evaluate

This commit is contained in:
2021-04-03 18:37:03 -05:00
parent 7af22bd16b
commit 185b95118a
5 changed files with 59 additions and 95 deletions

View File

@@ -4,7 +4,12 @@
from typing import Dict
def classifier_evaluation_dict(tp: int, tn: int, fp: int, fn: int) -> Dict:
def classifier_evaluation_dict(
tp: int,
tn: int,
fp: int,
fn: int,
) -> Dict[str, float]:
p = tp + fn
n = fp + tn
d: Dict = {

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import numpy as np
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict, Optional
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict, Optional, Hashable
from miplearn.extractors import InstanceIterator
from miplearn.instance import Instance
@@ -205,5 +205,9 @@ class Component:
ev += [self.sample_evaluate(instance.features, sample)]
return ev
def sample_evaluate(self, features: Features, sample: TrainingSample) -> Dict:
def sample_evaluate(
self,
features: Features,
sample: TrainingSample,
) -> Dict[Hashable, Dict[str, float]]:
return {}

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple, Hashable
import numpy as np
from sklearn.linear_model import LinearRegression
@@ -149,3 +149,25 @@ class ObjectiveValueComponent(Component):
if "Upper bound" in sample and sample["Upper bound"] is not None:
y["Upper bound"] = [[sample["Upper bound"]]]
return x, y
def sample_evaluate(
self,
features: Features,
sample: TrainingSample,
) -> Dict[Hashable, Dict[str, float]]:
def compare(y_pred: float, y_actual: float) -> Dict[str, float]:
err = np.round(abs(y_pred - y_actual), 8)
return {
"Actual value": y_actual,
"Predicted value": y_pred,
"Absolute error": err,
"Relative error": err / y_actual,
}
result: Dict[Hashable, Dict[str, float]] = {}
pred = self.sample_predict(features, sample)
if "Upper bound" in sample and sample["Upper bound"] is not None:
result["Upper bound"] = compare(pred["Upper bound"], sample["Upper bound"])
if "Lower bound" in sample and sample["Lower bound"] is not None:
result["Lower bound"] = compare(pred["Lower bound"], sample["Lower bound"])
return result

View File

@@ -186,7 +186,7 @@ class PrimalSolutionComponent(Component):
self,
features: Features,
sample: TrainingSample,
) -> Dict:
) -> Dict[Hashable, Dict[str, float]]:
solution_actual = sample["Solution"]
assert solution_actual is not None
solution_pred = self.sample_predict(features, sample)