mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Objective: Rewrite sample_evaluate
This commit is contained in:
@@ -4,7 +4,12 @@
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def classifier_evaluation_dict(tp: int, tn: int, fp: int, fn: int) -> Dict:
|
||||
def classifier_evaluation_dict(
|
||||
tp: int,
|
||||
tn: int,
|
||||
fp: int,
|
||||
fn: int,
|
||||
) -> Dict[str, float]:
|
||||
p = tp + fn
|
||||
n = fp + tn
|
||||
d: Dict = {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict, Optional
|
||||
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict, Optional, Hashable
|
||||
|
||||
from miplearn.extractors import InstanceIterator
|
||||
from miplearn.instance import Instance
|
||||
@@ -205,5 +205,9 @@ class Component:
|
||||
ev += [self.sample_evaluate(instance.features, sample)]
|
||||
return ev
|
||||
|
||||
def sample_evaluate(self, features: Features, sample: TrainingSample) -> Dict:
|
||||
def sample_evaluate(
|
||||
self,
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Dict[Hashable, Dict[str, float]]:
|
||||
return {}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple
|
||||
from typing import List, Dict, Union, Optional, Any, TYPE_CHECKING, Tuple, Hashable
|
||||
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LinearRegression
|
||||
@@ -149,3 +149,25 @@ class ObjectiveValueComponent(Component):
|
||||
if "Upper bound" in sample and sample["Upper bound"] is not None:
|
||||
y["Upper bound"] = [[sample["Upper bound"]]]
|
||||
return x, y
|
||||
|
||||
def sample_evaluate(
|
||||
self,
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Dict[Hashable, Dict[str, float]]:
|
||||
def compare(y_pred: float, y_actual: float) -> Dict[str, float]:
|
||||
err = np.round(abs(y_pred - y_actual), 8)
|
||||
return {
|
||||
"Actual value": y_actual,
|
||||
"Predicted value": y_pred,
|
||||
"Absolute error": err,
|
||||
"Relative error": err / y_actual,
|
||||
}
|
||||
|
||||
result: Dict[Hashable, Dict[str, float]] = {}
|
||||
pred = self.sample_predict(features, sample)
|
||||
if "Upper bound" in sample and sample["Upper bound"] is not None:
|
||||
result["Upper bound"] = compare(pred["Upper bound"], sample["Upper bound"])
|
||||
if "Lower bound" in sample and sample["Lower bound"] is not None:
|
||||
result["Lower bound"] = compare(pred["Lower bound"], sample["Lower bound"])
|
||||
return result
|
||||
|
||||
@@ -186,7 +186,7 @@ class PrimalSolutionComponent(Component):
|
||||
self,
|
||||
features: Features,
|
||||
sample: TrainingSample,
|
||||
) -> Dict:
|
||||
) -> Dict[Hashable, Dict[str, float]]:
|
||||
solution_actual = sample["Solution"]
|
||||
assert solution_actual is not None
|
||||
solution_pred = self.sample_predict(features, sample)
|
||||
|
||||
Reference in New Issue
Block a user