mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Refactor PrimalSolutionComponent
This commit is contained in:
@@ -142,8 +142,11 @@ class ScikitLearnClassifier(Classifier):
|
||||
|
||||
def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
|
||||
super().fit(x_train, y_train)
|
||||
(n_samples, n_classes) = x_train.shape
|
||||
assert n_classes == 2, "scikit-learn classifiers must have exactly two classes"
|
||||
(n_samples, n_classes) = y_train.shape
|
||||
assert n_classes == 2, (
|
||||
f"Scikit-learn classifiers must have exactly two classes. "
|
||||
f"{n_classes} classes were provided instead."
|
||||
)
|
||||
self.inner_clf.fit(x_train, y_train[:, 1])
|
||||
|
||||
def predict_proba(self, x_test: np.ndarray) -> np.ndarray:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics._ranking import _binary_clf_curve
|
||||
@@ -42,10 +42,10 @@ class Threshold(ABC):
|
||||
assert y_train.shape[0] == n_samples
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, x_test: np.ndarray) -> float:
|
||||
def predict(self, x_test: np.ndarray) -> List[float]:
|
||||
"""
|
||||
Returns the minimum probability for a machine learning prediction to be
|
||||
considered trustworthy.
|
||||
considered trustworthy. There is one value for each label.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -56,13 +56,13 @@ class MinProbabilityThreshold(Threshold):
|
||||
correct, as computed by the machine learning models, are above a fixed value.
|
||||
"""
|
||||
|
||||
def __init__(self, min_probability: float):
|
||||
def __init__(self, min_probability: List[float]):
|
||||
self.min_probability = min_probability
|
||||
|
||||
def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def predict(self, x_test: np.ndarray) -> float:
|
||||
def predict(self, x_test: np.ndarray) -> List[float]:
|
||||
return self.min_probability
|
||||
|
||||
|
||||
@@ -73,21 +73,41 @@ class MinPrecisionThreshold(Threshold):
|
||||
set. Note that increasing a component's minimum precision may reduce its recall.
|
||||
"""
|
||||
|
||||
def __init__(self, min_precision: float) -> None:
|
||||
def __init__(self, min_precision: List[float]) -> None:
|
||||
self.min_precision = min_precision
|
||||
self._computed_threshold: Optional[float] = None
|
||||
self._computed_threshold: Optional[List[float]] = None
|
||||
|
||||
def fit(self, clf: Classifier, x_train: np.ndarray, y_train: np.ndarray) -> None:
|
||||
def fit(
|
||||
self,
|
||||
clf: Classifier,
|
||||
x_train: np.ndarray,
|
||||
y_train: np.ndarray,
|
||||
) -> None:
|
||||
super().fit(clf, x_train, y_train)
|
||||
(n_samples, n_classes) = y_train.shape
|
||||
proba = clf.predict_proba(x_train)
|
||||
fps, tps, thresholds = _binary_clf_curve(y_train, proba[:, 1])
|
||||
precision = tps / (tps + fps)
|
||||
for k in reversed(range(len(precision))):
|
||||
if precision[k] >= self.min_precision:
|
||||
self._computed_threshold = thresholds[k]
|
||||
return
|
||||
self._computed_threshold = float("inf")
|
||||
self._computed_threshold = [
|
||||
self._compute(
|
||||
y_train[:, i],
|
||||
proba[:, i],
|
||||
self.min_precision[i],
|
||||
)
|
||||
for i in range(n_classes)
|
||||
]
|
||||
|
||||
def predict(self, x_test: np.ndarray) -> float:
|
||||
def predict(self, x_test: np.ndarray) -> List[float]:
|
||||
assert self._computed_threshold is not None
|
||||
return self._computed_threshold
|
||||
|
||||
@staticmethod
|
||||
def _compute(
|
||||
y_actual: np.ndarray,
|
||||
y_prob: np.ndarray,
|
||||
min_precision: float,
|
||||
) -> float:
|
||||
fps, tps, thresholds = _binary_clf_curve(y_actual, y_prob)
|
||||
precision = tps / (tps + fps)
|
||||
for k in reversed(range(len(precision))):
|
||||
if precision[k] >= min_precision:
|
||||
return thresholds[k]
|
||||
return float("inf")
|
||||
|
||||
Reference in New Issue
Block a user