Refactor PrimalSolutionComponent

This commit is contained in:
2021-01-25 14:54:58 -06:00
parent f68cc5bd59
commit 3ab3bb3c1f
9 changed files with 501 additions and 233 deletions

View File

@@ -142,8 +142,11 @@ class ScikitLearnClassifier(Classifier):
def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
super().fit(x_train, y_train)
(n_samples, n_classes) = x_train.shape
assert n_classes == 2, "scikit-learn classifiers must have exactly two classes"
(n_samples, n_classes) = y_train.shape
assert n_classes == 2, (
f"Scikit-learn classifiers must have exactly two classes. "
f"{n_classes} classes were provided instead."
)
self.inner_clf.fit(x_train, y_train[:, 1])
def predict_proba(self, x_test: np.ndarray) -> np.ndarray:

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import abstractmethod, ABC
from typing import Optional
from typing import Optional, List
import numpy as np
from sklearn.metrics._ranking import _binary_clf_curve
@@ -42,10 +42,10 @@ class Threshold(ABC):
assert y_train.shape[0] == n_samples
@abstractmethod
def predict(self, x_test: np.ndarray) -> float:
def predict(self, x_test: np.ndarray) -> List[float]:
"""
Returns the minimum probability for a machine learning prediction to be
considered trustworthy.
considered trustworthy. There is one value for each label.
"""
pass
@@ -56,13 +56,13 @@ class MinProbabilityThreshold(Threshold):
correct, as computed by the machine learning models, are above a fixed value.
"""
def __init__(self, min_probability: float):
def __init__(self, min_probability: List[float]):
self.min_probability = min_probability
def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
pass
def predict(self, x_test: np.ndarray) -> float:
def predict(self, x_test: np.ndarray) -> List[float]:
return self.min_probability
@@ -73,21 +73,41 @@ class MinPrecisionThreshold(Threshold):
set. Note that increasing a component's minimum precision may reduce its recall.
"""
def __init__(self, min_precision: float) -> None:
def __init__(self, min_precision: List[float]) -> None:
self.min_precision = min_precision
self._computed_threshold: Optional[float] = None
self._computed_threshold: Optional[List[float]] = None
def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
def fit(
self,
clf: Classifier,
x_train: np.ndarray,
y_train: np.ndarray,
) -> None:
super().fit(clf, x_train, y_train)
(n_samples, n_classes) = y_train.shape
proba = clf.predict_proba(x_train)
fps, tps, thresholds = _binary_clf_curve(y_train, proba[:, 1])
precision = tps / (tps + fps)
for k in reversed(range(len(precision))):
if precision[k] >= self.min_precision:
self._computed_threshold = thresholds[k]
return
self._computed_threshold = float("inf")
self._computed_threshold = [
self._compute(
y_train[:, i],
proba[:, i],
self.min_precision[i],
)
for i in range(n_classes)
]
def predict(self, x_test: np.ndarray) -> float:
def predict(self, x_test: np.ndarray) -> List[float]:
assert self._computed_threshold is not None
return self._computed_threshold
@staticmethod
def _compute(
y_actual: np.ndarray,
y_prob: np.ndarray,
min_precision: float,
) -> float:
fps, tps, thresholds = _binary_clf_curve(y_actual, y_prob)
precision = tps / (tps + fps)
for k in reversed(range(len(precision))):
if precision[k] >= min_precision:
return thresholds[k]
return float("inf")