You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
MIPLearn/miplearn/classifiers/cv.py

133 lines
5.0 KiB

# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Optional, List
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from miplearn.classifiers import Classifier
from miplearn.classifiers.sklearn import ScikitLearnClassifier
logger = logging.getLogger(__name__)
class CrossValidatedClassifier(Classifier):
    """
    A meta-classifier that, upon training, evaluates the performance of another
    candidate classifier on the training data set, using k-fold cross validation,
    then either adopts it, if its cv-score is high enough, or returns constant
    predictions for every x_test, otherwise.

    Parameters
    ----------
    classifier: Optional[ScikitLearnClassifier]
        Prototype of the candidate classifier. The prototype itself is never
        trained; `fit` works on a fresh copy obtained through its `clone` method.
        If None, a ScikitLearnClassifier wrapping a LogisticRegression is used.
    threshold: float
        Number from zero to one indicating how well must the candidate classifier
        perform to be adopted. The threshold is specified in comparison to a dummy
        classifier trained on the same dataset. For example, a threshold of 0.0
        indicates that any classifier as good as the dummy predictor is acceptable. A
        threshold of 1.0 indicates that only classifiers with perfect
        cross-validation scores are acceptable. Other numbers are a linear
        interpolation of these two extremes.
    constant: Optional[List[bool]]
        If the candidate classifier fails to meet the threshold, use a dummy
        classifier which always returns this prediction instead. The list should
        have exactly as many elements as the number of columns of `x_train`
        provided to `fit`. Defaults to `[True, False]`.
    cv: int
        Number of folds.
    scoring: str
        Scoring function, forwarded to scikit-learn's `cross_val_score`.
    """

    def __init__(
        self,
        classifier: Optional[ScikitLearnClassifier] = None,
        threshold: float = 0.75,
        constant: Optional[List[bool]] = None,
        cv: int = 5,
        scoring: str = "accuracy",
    ):
        super().__init__()
        # Build the defaults per instance instead of at function-definition
        # time, so that no object is silently shared between instances.
        if classifier is None:
            classifier = ScikitLearnClassifier(LogisticRegression())
        if constant is None:
            constant = [True, False]
        self.n_classes = len(constant)
        # Trained classifier selected by `fit`; None until `fit` is called.
        self.classifier: Optional[ScikitLearnClassifier] = None
        self.classifier_prototype = classifier
        # Keep a private copy, so that later changes to the caller's list (or
        # to the list held by a clone) do not leak into this instance.
        self.constant: List[bool] = list(constant)
        self.threshold = threshold
        self.cv = cv
        self.scoring = scoring

    def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
        """
        Cross-validate the candidate classifier, then train either the candidate
        or a constant dummy classifier on the provided data set.

        Column 1 of `y_train` is the label used during cross validation.
        """
        super().fit(x_train, y_train)
        (_, n_columns) = x_train.shape
        # NOTE(review): this compares the number of columns of `x_train` with
        # `len(constant)`, as the class docstring states. Comparing against the
        # number of columns of `y_train` may have been intended -- confirm
        # against the base class before changing it.
        assert n_columns == self.n_classes

        # Calculate dummy score and absolute score threshold. The baseline must
        # be computed on the same labels that are cross-validated below (column
        # 1). Averaging the entire `y_train` matrix would not measure the label
        # frequency; for complementary columns it would always yield 0.5.
        # The baseline is the accuracy of a majority-label predictor, so it is
        # most meaningful when `scoring` is "accuracy".
        labels = y_train[:, 1]
        positive_rate = float(np.average(labels))
        dummy_score = max(positive_rate, 1 - positive_rate)
        absolute_threshold = self.threshold + dummy_score * (1 - self.threshold)

        # Calculate cross validation score and decide which classifier to use
        clf = self.classifier_prototype.clone()
        assert clf is not None
        assert isinstance(clf, ScikitLearnClassifier), (
            f"The provided classifier callable must return a ScikitLearnClassifier. "
            f"Found {clf.__class__.__name__} instead. If this is a scikit-learn "
            f"classifier, you must wrap it with ScikitLearnClassifier."
        )
        cv_score = float(
            np.mean(
                cross_val_score(
                    clf.inner_clf,
                    x_train,
                    labels,
                    cv=self.cv,
                    scoring=self.scoring,
                )
            )
        )
        if cv_score >= absolute_threshold:
            logger.debug(
                "cv_score is above threshold (%.2f >= %.2f); keeping",
                cv_score,
                absolute_threshold,
            )
            self.classifier = clf
        else:
            logger.debug(
                "cv_score is below threshold (%.2f < %.2f); discarding",
                cv_score,
                absolute_threshold,
            )
            # Fall back to a classifier that always predicts `constant[1]`.
            self.classifier = ScikitLearnClassifier(
                DummyClassifier(
                    strategy="constant",
                    constant=self.constant[1],
                )
            )

        # Train chosen classifier
        assert self.classifier is not None
        assert isinstance(self.classifier, ScikitLearnClassifier)
        self.classifier.fit(x_train, y_train)

    def predict_proba(self, x_test: np.ndarray) -> np.ndarray:
        """
        Return the predictions of the classifier selected by `fit`.

        `fit` must have been called first; otherwise the assertion below fails.
        """
        super().predict_proba(x_test)
        assert self.classifier is not None
        return self.classifier.predict_proba(x_test)

    def clone(self) -> "CrossValidatedClassifier":
        """
        Return a new, untrained CrossValidatedClassifier with the same settings.

        The candidate prototype is shared with the copy; the constructor takes
        its own copy of the `constant` list.
        """
        return CrossValidatedClassifier(
            classifier=self.classifier_prototype,
            threshold=self.threshold,
            constant=self.constant,
            cv=self.cv,
            scoring=self.scoring,
        )