mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Update
This commit is contained in:
@@ -38,7 +38,7 @@ class Classifier(ABC):
|
||||
np.float16,
|
||||
np.float32,
|
||||
np.float64,
|
||||
], f"x_train.dtype shoule be float. Found {x_train.dtype} instead."
|
||||
], f"x_train.dtype should be float. Found {x_train.dtype} instead."
|
||||
assert y_train.dtype == np.bool8
|
||||
assert len(x_train.shape) == 2
|
||||
assert len(y_train.shape) == 2
|
||||
|
||||
@@ -6,8 +6,10 @@ import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import roc_auc_score
|
||||
from sklearn.model_selection import cross_val_predict
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
@@ -63,6 +65,15 @@ class AdaptiveClassifier(Classifier):
|
||||
super().__init__()
|
||||
if candidates is None:
|
||||
candidates = {
|
||||
"forest(5,10)": CandidateClassifierSpecs(
|
||||
classifier=ScikitLearnClassifier(
|
||||
RandomForestClassifier(
|
||||
n_estimators=5,
|
||||
min_samples_split=10,
|
||||
),
|
||||
),
|
||||
min_samples=100,
|
||||
),
|
||||
"knn(100)": CandidateClassifierSpecs(
|
||||
classifier=ScikitLearnClassifier(
|
||||
KNeighborsClassifier(n_neighbors=100)
|
||||
@@ -92,7 +103,7 @@ class AdaptiveClassifier(Classifier):
|
||||
|
||||
# If almost all samples belong to the same class, return a fixed prediction and
|
||||
# skip all the other steps.
|
||||
if y_train[:, 0].mean() > 0.999 or y_train[:, 1].mean() > 0.999:
|
||||
if y_train[:, 0].mean() > 0.99 or y_train[:, 1].mean() > 0.99:
|
||||
self.classifier = CountingClassifier()
|
||||
self.classifier.fit(x_train, y_train)
|
||||
return
|
||||
@@ -102,13 +113,17 @@ class AdaptiveClassifier(Classifier):
|
||||
if n_samples < specs.min_samples:
|
||||
continue
|
||||
clf = specs.classifier.clone()
|
||||
clf.fit(x_train, y_train)
|
||||
proba = clf.predict_proba(x_train)
|
||||
# FIXME: Switch to k-fold cross validation
|
||||
score = roc_auc_score(y_train[:, 1], proba[:, 1])
|
||||
if isinstance(clf, ScikitLearnClassifier):
|
||||
proba = cross_val_predict(clf.inner_clf, x_train, y_train[:, 1])
|
||||
else:
|
||||
clf.fit(x_train, y_train)
|
||||
proba = clf.predict_proba(x_train)[:, 1]
|
||||
score = roc_auc_score(y_train[:, 1], proba)
|
||||
if score > best_score:
|
||||
best_name, best_clf, best_score = name, clf, score
|
||||
logger.debug("Best classifier: %s (score=%.3f)" % (best_name, best_score))
|
||||
if isinstance(best_clf, ScikitLearnClassifier):
|
||||
best_clf.fit(x_train, y_train)
|
||||
self.classifier = best_clf
|
||||
|
||||
def predict_proba(self, x_test: np.ndarray) -> np.ndarray:
|
||||
|
||||
@@ -7,7 +7,10 @@ from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
from sklearn.metrics._ranking import _binary_clf_curve
|
||||
from sklearn.model_selection import cross_val_predict
|
||||
|
||||
from miplearn.classifiers.sklearn import ScikitLearnClassifier
|
||||
from miplearn.classifiers.adaptive import AdaptiveClassifier
|
||||
from miplearn.classifiers import Classifier
|
||||
|
||||
|
||||
@@ -95,7 +98,17 @@ class MinPrecisionThreshold(Threshold):
|
||||
) -> None:
|
||||
super().fit(clf, x_train, y_train)
|
||||
(n_samples, n_classes) = y_train.shape
|
||||
proba = clf.predict_proba(x_train)
|
||||
if isinstance(clf, AdaptiveClassifier) and isinstance(
|
||||
clf.classifier, ScikitLearnClassifier
|
||||
):
|
||||
proba = cross_val_predict(
|
||||
clf.classifier.inner_clf,
|
||||
x_train,
|
||||
y_train[:, 1],
|
||||
method="predict_proba",
|
||||
)
|
||||
else:
|
||||
proba = clf.predict_proba(x_train)
|
||||
self._computed_threshold = [
|
||||
self._compute(
|
||||
y_train[:, i],
|
||||
@@ -114,11 +127,13 @@ class MinPrecisionThreshold(Threshold):
|
||||
y_actual: np.ndarray,
|
||||
y_prob: np.ndarray,
|
||||
min_precision: float,
|
||||
min_recall: float = 0.1,
|
||||
) -> float:
|
||||
fps, tps, thresholds = _binary_clf_curve(y_actual, y_prob)
|
||||
precision = tps / (tps + fps)
|
||||
recall = tps / tps[-1]
|
||||
for k in reversed(range(len(precision))):
|
||||
if precision[k] >= min_precision:
|
||||
if precision[k] >= min_precision and recall[k] >= min_recall:
|
||||
return thresholds[k]
|
||||
return float("inf")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user