From bcaf26b18c0e96879dd979b8832aacf060ebe4d6 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 2 Mar 2021 19:31:12 -0600 Subject: [PATCH] Sklearn: Handle the special case when all labels are the same --- miplearn/classifiers/__init__.py | 14 ++++++++++- tests/classifiers/test_sklearn.py | 40 +++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 tests/classifiers/test_sklearn.py diff --git a/miplearn/classifiers/__init__.py b/miplearn/classifiers/__init__.py index 0846c09..fc262f7 100644 --- a/miplearn/classifiers/__init__.py +++ b/miplearn/classifiers/__init__.py @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. from abc import ABC, abstractmethod -from typing import Optional, Any +from typing import Optional, Any, cast import numpy as np @@ -139,6 +139,7 @@ class ScikitLearnClassifier(Classifier): def __init__(self, clf: Any) -> None: super().__init__() self.inner_clf = clf + self.constant: Optional[np.ndarray] = None def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None: super().fit(x_train, y_train) @@ -147,11 +148,22 @@ class ScikitLearnClassifier(Classifier): f"Scikit-learn classifiers must have exactly two classes. " f"{n_classes} classes were provided instead." ) + + # When all samples belong to the same class, sklearn's predict_proba returns + # an array with a single column. The following check avoid this strange + # behavior. + mean = cast(np.ndarray, y_train.astype(float).mean(axis=0)) + if mean.max() == 1.0: + self.constant = mean + return + self.inner_clf.fit(x_train, y_train[:, 1]) def predict_proba(self, x_test: np.ndarray) -> np.ndarray: super().predict_proba(x_test) n_samples = x_test.shape[0] + if self.constant is not None: + return np.array([self.constant for n in range(n_samples)]) sklearn_proba = self.inner_clf.predict_proba(x_test) if isinstance(sklearn_proba, list): assert len(sklearn_proba) == self.n_classes diff --git a/tests/classifiers/test_sklearn.py b/tests/classifiers/test_sklearn.py new file mode 100644 index 0000000..1464b20 --- /dev/null +++ b/tests/classifiers/test_sklearn.py @@ -0,0 +1,40 @@ +# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization +# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. +# Released under the modified BSD license. See COPYING.md for more details. + +import numpy as np +from numpy.testing import assert_array_equal +from sklearn.neighbors import KNeighborsClassifier + +from miplearn import ScikitLearnClassifier + + +def test_constant_prediction(): + x_train = np.array( + [ + [0.0, 1.0], + [1.0, 0.0], + ] + ) + y_train = np.array( + [ + [True, False], + [True, False], + ] + ) + clf = ScikitLearnClassifier( + KNeighborsClassifier( + n_neighbors=1, + ) + ) + clf.fit(x_train, y_train) + proba = clf.predict_proba(x_train) + assert_array_equal( + proba, + np.array( + [ + [1.0, 0.0], + [1.0, 0.0], + ] + ), + )