mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Sklearn: Handle the special case when all labels are the same
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -139,6 +139,7 @@ class ScikitLearnClassifier(Classifier):
|
||||
def __init__(self, clf: Any) -> None:
|
||||
super().__init__()
|
||||
self.inner_clf = clf
|
||||
self.constant: Optional[np.ndarray] = None
|
||||
|
||||
def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
|
||||
super().fit(x_train, y_train)
|
||||
@@ -147,11 +148,22 @@ class ScikitLearnClassifier(Classifier):
|
||||
f"Scikit-learn classifiers must have exactly two classes. "
|
||||
f"{n_classes} classes were provided instead."
|
||||
)
|
||||
|
||||
# When all samples belong to the same class, sklearn's predict_proba returns
|
||||
# an array with a single column. The following check avoids this strange
|
||||
# behavior.
|
||||
mean = cast(np.ndarray, y_train.astype(float).mean(axis=0))
|
||||
if mean.max() == 1.0:
|
||||
self.constant = mean
|
||||
return
|
||||
|
||||
self.inner_clf.fit(x_train, y_train[:, 1])
|
||||
|
||||
def predict_proba(self, x_test: np.ndarray) -> np.ndarray:
|
||||
super().predict_proba(x_test)
|
||||
n_samples = x_test.shape[0]
|
||||
if self.constant is not None:
|
||||
return np.array([self.constant for n in range(n_samples)])
|
||||
sklearn_proba = self.inner_clf.predict_proba(x_test)
|
||||
if isinstance(sklearn_proba, list):
|
||||
assert len(sklearn_proba) == self.n_classes
|
||||
|
||||
Reference in New Issue
Block a user