Refactor thresholds

This commit is contained in:
2021-01-25 09:52:49 -06:00
parent 4da561a6a8
commit f68cc5bd59
4 changed files with 82 additions and 41 deletions

View File

@@ -3,6 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import abstractmethod, ABC
from typing import Optional
import numpy as np
from sklearn.metrics._ranking import _binary_clf_curve
@@ -10,47 +11,83 @@ from sklearn.metrics._ranking import _binary_clf_curve
from miplearn.classifiers import Classifier
class Threshold(ABC):
    """
    Solver components ask the machine learning models how confident they are in
    each prediction they make, then automatically discard all predictions that
    have low confidence. A Threshold specifies how confident the ML models should
    be for a prediction to be considered trustworthy.

    To model dynamic thresholds, which automatically adjust themselves during
    training to reach some desired target (such as minimum precision, or minimum
    recall), thresholds behave somewhat similarly to ML models themselves, with
    `fit` and `predict` methods.
    """

    @abstractmethod
    def fit(
        self,
        clf: Classifier,
        x_train: np.ndarray,
        y_train: np.ndarray,
    ) -> None:
        """
        Given a trained binary classifier `clf`, calibrates itself based on the
        classifier's performance on the given training data set.

        This base implementation only validates its arguments: `clf` must be a
        `Classifier`, `x_train` and `y_train` must be numpy arrays, and both
        arrays must have the same number of rows. Subclasses may reuse these
        checks through `super().fit(clf, x_train, y_train)`.
        """
        assert isinstance(clf, Classifier), "clf should be a Classifier"
        assert isinstance(x_train, np.ndarray), "x_train should be a numpy array"
        assert isinstance(y_train, np.ndarray), "y_train should be a numpy array"
        n_samples = x_train.shape[0]
        assert (
            y_train.shape[0] == n_samples
        ), "x_train and y_train should have the same number of samples"

    @abstractmethod
    def predict(self, x_test: np.ndarray) -> float:
        """
        Returns the minimum probability for a machine learning prediction to be
        considered trustworthy.
        """
        pass
class MinProbabilityThreshold(Threshold):
    """
    A threshold which considers predictions trustworthy if their probability of
    being correct, as computed by the machine learning models, is above a fixed
    value.
    """

    def __init__(self, min_probability: float) -> None:
        self.min_probability = min_probability

    def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
        # The threshold is fixed at construction time; there is nothing to
        # calibrate, so the training data is deliberately ignored.
        pass

    def predict(self, x_test: np.ndarray) -> float:
        """
        Returns the fixed `min_probability` given at construction, regardless
        of `x_test`.
        """
        return self.min_probability
class MinPrecisionThreshold(Threshold):
    """
    A dynamic threshold which automatically adjusts itself during training to
    ensure that the component achieves at least a given precision `p` on the
    training data set. Note that increasing a component's minimum precision may
    reduce its recall.
    """

    def __init__(self, min_precision: float) -> None:
        self.min_precision = min_precision
        # Set by `fit`; stays None until the threshold has been calibrated.
        self._computed_threshold: Optional[float] = None

    def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
        """
        Calibrates the threshold using the probabilities that `clf` assigns to
        the training samples.

        Stores the threshold at the last index of the curve whose precision is
        at least `min_precision`, or infinity if no index reaches that
        precision.
        """
        super().fit(clf, x_train, y_train)
        proba = clf.predict_proba(x_train)
        assert isinstance(proba, np.ndarray), "classifier should return numpy array"
        assert proba.shape == (
            x_train.shape[0],
            2,
        ), "classifier should return (%d,%d)-shaped array, not %s" % (
            x_train.shape[0],
            2,
            str(proba.shape),
        )
        # Column 1 is used as the score of the positive class.
        fps, tps, thresholds = _binary_clf_curve(y_train, proba[:, 1])
        precision = tps / (tps + fps)
        # NOTE(review): sklearn reports thresholds in decreasing order, so the
        # last qualifying index should be the smallest qualifying threshold --
        # confirm against the installed sklearn version.
        qualifying = np.flatnonzero(precision >= self.min_precision)
        if qualifying.size > 0:
            self._computed_threshold = float(thresholds[qualifying[-1]])
        else:
            # No threshold reaches the target precision: use infinity.
            self._computed_threshold = float("inf")

    def predict(self, x_test: np.ndarray) -> float:
        """
        Returns the threshold computed by the most recent call to `fit`.
        `fit` must have been called beforehand.
        """
        assert (
            self._computed_threshold is not None
        ), "fit must be called before predict"
        return self._computed_threshold