mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Refactor thresholds
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics._ranking import _binary_clf_curve
|
||||
@@ -10,47 +11,83 @@ from sklearn.metrics._ranking import _binary_clf_curve
|
||||
from miplearn.classifiers import Classifier
|
||||
|
||||
|
||||
class DynamicThreshold(ABC):
|
||||
class Threshold(ABC):
|
||||
"""
|
||||
Solver components ask the machine learning models how confident they are in each
|
||||
prediction they make, then automatically discard all predictions that have low
|
||||
confidence. A Threshold specifies how confident the ML models should be for a
|
||||
prediction to be considered trustworthy.
|
||||
|
||||
To model dynamic thresholds, which automatically adjust themselves during
|
||||
training to reach some desired target (such as minimum precision, or minimum
|
||||
recall), thresholds behave somewhat similarly to ML models themselves, with `fit`
|
||||
and `predict` methods.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def find(
|
||||
def fit(
|
||||
self,
|
||||
clf: Classifier,
|
||||
x_train: np.ndarray,
|
||||
y_train: np.ndarray,
|
||||
) -> float:
|
||||
) -> None:
|
||||
"""
|
||||
Given a trained binary classifier `clf` and a training data set,
|
||||
returns the numerical threshold (float) satisfying some criteria.
|
||||
Given a trained binary classifier `clf`, calibrates itself based on the
|
||||
classifier's performance on the given training data set.
|
||||
"""
|
||||
assert isinstance(clf, Classifier)
|
||||
assert isinstance(x_train, np.ndarray)
|
||||
assert isinstance(y_train, np.ndarray)
|
||||
n_samples = x_train.shape[0]
|
||||
assert y_train.shape[0] == n_samples
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, x_test: np.ndarray) -> float:
|
||||
"""
|
||||
Returns the minimum probability for a machine learning prediction to be
|
||||
considered trustworthy.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MinPrecisionThreshold(DynamicThreshold):
|
||||
class MinProbabilityThreshold(Threshold):
|
||||
"""
|
||||
The smallest possible threshold satisfying a minimum acceptable positive
|
||||
predictive value (also known as precision).
|
||||
A threshold which considers predictions trustworthy if their probability of being
|
||||
correct, as computed by the machine learning models, is above a fixed value.
|
||||
"""
|
||||
|
||||
def __init__(self, min_probability: float):
|
||||
self.min_probability = min_probability
|
||||
|
||||
def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def predict(self, x_test: np.ndarray) -> float:
|
||||
return self.min_probability
|
||||
|
||||
|
||||
class MinPrecisionThreshold(Threshold):
|
||||
"""
|
||||
A dynamic threshold which automatically adjusts itself during training to ensure
|
||||
that the component achieves at least a given precision `p` on the training data
|
||||
set. Note that increasing a component's minimum precision may reduce its recall.
|
||||
"""
|
||||
|
||||
def __init__(self, min_precision: float) -> None:
|
||||
self.min_precision = min_precision
|
||||
self._computed_threshold: Optional[float] = None
|
||||
|
||||
def find(self, clf, x_train, y_train):
|
||||
def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
|
||||
super().fit(clf, x_train, y_train)
|
||||
proba = clf.predict_proba(x_train)
|
||||
|
||||
assert isinstance(proba, np.ndarray), "classifier should return numpy array"
|
||||
assert proba.shape == (
|
||||
x_train.shape[0],
|
||||
2,
|
||||
), "classifier should return (%d,%d)-shaped array, not %s" % (
|
||||
x_train.shape[0],
|
||||
2,
|
||||
str(proba.shape),
|
||||
)
|
||||
|
||||
fps, tps, thresholds = _binary_clf_curve(y_train, proba[:, 1])
|
||||
precision = tps / (tps + fps)
|
||||
|
||||
for k in reversed(range(len(precision))):
|
||||
if precision[k] >= self.min_precision:
|
||||
return thresholds[k]
|
||||
return 2.0
|
||||
self._computed_threshold = thresholds[k]
|
||||
return
|
||||
self._computed_threshold = float("inf")
|
||||
|
||||
def predict(self, x_test: np.ndarray) -> float:
|
||||
assert self._computed_threshold is not None
|
||||
return self._computed_threshold
|
||||
|
||||
@@ -11,7 +11,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.adaptive import AdaptiveClassifier
|
||||
from miplearn.classifiers.threshold import MinPrecisionThreshold, DynamicThreshold
|
||||
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.extractors import VariableFeaturesExtractor, SolutionExtractor, Extractor
|
||||
@@ -28,11 +28,11 @@ class PrimalSolutionComponent(Component):
|
||||
self,
|
||||
classifier: Classifier = AdaptiveClassifier(),
|
||||
mode: str = "exact",
|
||||
threshold: Union[float, DynamicThreshold] = MinPrecisionThreshold(0.98),
|
||||
threshold: Union[float, Threshold] = MinPrecisionThreshold(0.98),
|
||||
) -> None:
|
||||
self.mode = mode
|
||||
self.classifiers: Dict[Any, Classifier] = {}
|
||||
self.thresholds: Dict[Any, Union[float, DynamicThreshold]] = {}
|
||||
self.thresholds: Dict[Any, Union[float, Threshold]] = {}
|
||||
self.threshold_prototype = threshold
|
||||
self.classifier_prototype = classifier
|
||||
|
||||
@@ -89,8 +89,8 @@ class PrimalSolutionComponent(Component):
|
||||
clf.fit(x_train, y_train)
|
||||
|
||||
# Find threshold (dynamic or static)
|
||||
if isinstance(self.threshold_prototype, DynamicThreshold):
|
||||
self.thresholds[category, label] = self.threshold_prototype.find(
|
||||
if isinstance(self.threshold_prototype, Threshold):
|
||||
self.thresholds[category, label] = self.threshold_prototype.fit(
|
||||
clf,
|
||||
x_train,
|
||||
y_train,
|
||||
|
||||
Reference in New Issue
Block a user