Sklearn: Handle the special case when all labels are the same
@@ -3,7 +3,7 @@
 # Released under the modified BSD license. See COPYING.md for more details.
 
 from abc import ABC, abstractmethod
-from typing import Optional, Any
+from typing import Optional, Any, cast
 
 import numpy as np
 
@@ -139,6 +139,7 @@ class ScikitLearnClassifier(Classifier):
     def __init__(self, clf: Any) -> None:
         super().__init__()
         self.inner_clf = clf
+        self.constant: Optional[np.ndarray] = None
 
     def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
         super().fit(x_train, y_train)
@@ -147,11 +148,22 @@ class ScikitLearnClassifier(Classifier):
             f"Scikit-learn classifiers must have exactly two classes. "
             f"{n_classes} classes were provided instead."
         )
+
+        # When all samples belong to the same class, sklearn's predict_proba returns
+        # an array with a single column. The following check avoids this strange
+        # behavior.
+        mean = cast(np.ndarray, y_train.astype(float).mean(axis=0))
+        if mean.max() == 1.0:
+            self.constant = mean
+            return
+
         self.inner_clf.fit(x_train, y_train[:, 1])
 
     def predict_proba(self, x_test: np.ndarray) -> np.ndarray:
         super().predict_proba(x_test)
         n_samples = x_test.shape[0]
+        if self.constant is not None:
+            return np.array([self.constant for n in range(n_samples)])
         sklearn_proba = self.inner_clf.predict_proba(x_test)
         if isinstance(sklearn_proba, list):
             assert len(sklearn_proba) == self.n_classes
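For context on the guard added to fit above: when a scikit-learn estimator is trained on labels containing only one class, predict_proba returns one probability column per observed class, so the output has a single column instead of the two columns the rest of the class expects. A minimal sketch of that underlying behavior (the variable names are illustrative only; KNeighborsClassifier is used simply because the new test below uses it):

import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Two training samples that both carry the same label.
x = np.array([[0.0, 1.0], [1.0, 0.0]])
y = np.array([1, 1])

knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(x, y)

print(knn.classes_)                 # [1] -- only one class was observed
print(knn.predict_proba(x).shape)   # (2, 1) -- a single column, not two

The commit sidesteps this case by never fitting the inner classifier when the labels are constant, and instead replaying the stored label distribution at prediction time.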
tests/classifiers/test_sklearn.py (new file, 40 lines)

@@ -0,0 +1,40 @@
+# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
+# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
+# Released under the modified BSD license. See COPYING.md for more details.
+
+import numpy as np
+from numpy.testing import assert_array_equal
+from sklearn.neighbors import KNeighborsClassifier
+
+from miplearn import ScikitLearnClassifier
+
+
+def test_constant_prediction():
+    x_train = np.array(
+        [
+            [0.0, 1.0],
+            [1.0, 0.0],
+        ]
+    )
+    y_train = np.array(
+        [
+            [True, False],
+            [True, False],
+        ]
+    )
+    clf = ScikitLearnClassifier(
+        KNeighborsClassifier(
+            n_neighbors=1,
+        )
+    )
+    clf.fit(x_train, y_train)
+    proba = clf.predict_proba(x_train)
+    assert_array_equal(
+        proba,
+        np.array(
+            [
+                [1.0, 0.0],
+                [1.0, 0.0],
+            ]
+        ),
+    )
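Assuming pytest is the project's test runner, the new test can be exercised on its own with something like:

    pytest tests/classifiers/test_sklearn.py::test_constant_prediction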