You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
133 lines
5.0 KiB
133 lines
5.0 KiB
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
import logging
|
|
from typing import Optional, List
|
|
|
|
import numpy as np
|
|
from sklearn.dummy import DummyClassifier
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.model_selection import cross_val_score
|
|
|
|
from miplearn.classifiers import Classifier
|
|
from miplearn.classifiers.sklearn import ScikitLearnClassifier
|
|
|
|
# Module-wide logger; its name is this module's import path (``__name__``).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class CrossValidatedClassifier(Classifier):
    """
    A meta-classifier that, upon training, evaluates the performance of another
    candidate classifier on the training data set, using k-fold cross validation,
    then either adopts it, if its cv-score is high enough, or returns constant
    predictions for every x_test, otherwise.

    Parameters
    ----------
    classifier: ScikitLearnClassifier
        Prototype of the candidate classifier. The prototype itself is never
        trained: `fit` works on a fresh copy obtained through its `clone()`.
    threshold: float
        Number from zero to one indicating how well the candidate classifier must
        perform to be adopted. The threshold is specified in comparison to a dummy
        classifier that always predicts the majority value of the training labels.
        For example, a threshold of 0.0 indicates that any classifier as good as
        the dummy predictor is acceptable. A threshold of 1.0 indicates that only
        classifiers with perfect cross-validation scores are acceptable. Other
        numbers are a linear interpolation of these two extremes.
    constant: Optional[List[bool]]
        If the candidate classifier fails to meet the threshold, use a dummy
        classifier which always returns this prediction instead. The list should
        have exactly as many elements as the number of columns of `y_train`
        provided to `fit`. Defaults to `[True, False]`.
    cv: int
        Number of folds.
    scoring: str
        Scoring function, passed to `sklearn.model_selection.cross_val_score`.
        The dummy baseline is computed as an accuracy, so the threshold is most
        meaningful with the default `"accuracy"` scoring.
    """

    def __init__(
        self,
        classifier: ScikitLearnClassifier = ScikitLearnClassifier(LogisticRegression()),
        threshold: float = 0.75,
        constant: Optional[List[bool]] = None,
        cv: int = 5,
        scoring: str = "accuracy",
    ):
        super().__init__()
        if constant is None:
            constant = [True, False]
        # Expected number of label columns (columns of y_train) in `fit`.
        self.n_classes = len(constant)
        # Classifier actually used for predictions; chosen and trained by `fit`.
        self.classifier: Optional[ScikitLearnClassifier] = None
        # NOTE: the default prototype is one instance shared by every object built
        # without an explicit `classifier`. That is safe only because this class
        # never trains the prototype; `fit` always trains `prototype.clone()`.
        self.classifier_prototype = classifier
        self.constant: List[bool] = constant
        self.threshold = threshold
        self.cv = cv
        self.scoring = scoring

    def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
        """
        Cross-validate the candidate classifier, keep it if it beats the
        threshold (otherwise fall back to a constant predictor), then train the
        chosen classifier on the full training set.

        Parameters
        ----------
        x_train: np.ndarray
            Training features, one row per sample.
        y_train: np.ndarray
            Training labels, one row per sample and one column per class. Column 1
            is the binary target used for cross validation and for the dummy
            baseline.
        """
        super().fit(x_train, y_train)

        # The *label* matrix must agree with `constant` (one predicted value per
        # class column); the number of feature columns is unrelated.
        n_classes = y_train.shape[1]
        assert n_classes == len(self.constant), (
            f"Expected y_train with {len(self.constant)} columns (one per element "
            f"of `constant`); found {n_classes}."
        )

        # Dummy baseline: accuracy of always predicting the majority value of the
        # target column. Averaging the whole label matrix would be wrong (for
        # one-hot rows it is always 0.5), so use y_train[:, 1], the same column
        # the candidate is cross-validated on below.
        y_train_avg = float(np.average(y_train[:, 1]))
        dummy_score = max(y_train_avg, 1.0 - y_train_avg)
        # Linear interpolation between the dummy score (threshold=0) and a
        # perfect score of 1.0 (threshold=1).
        absolute_threshold = 1.0 * self.threshold + dummy_score * (1 - self.threshold)

        # Calculate cross validation score and decide which classifier to use
        clf = self.classifier_prototype.clone()
        assert clf is not None
        assert isinstance(clf, ScikitLearnClassifier), (
            f"The provided classifier must be a ScikitLearnClassifier. "
            f"Found {clf.__class__.__name__} instead. If this is a scikit-learn "
            f"classifier, you must wrap it with ScikitLearnClassifier."
        )

        cv_score = float(
            np.mean(
                cross_val_score(
                    clf.inner_clf,
                    x_train,
                    y_train[:, 1],
                    cv=self.cv,
                    scoring=self.scoring,
                )
            )
        )

        chosen: ScikitLearnClassifier
        if cv_score >= absolute_threshold:
            logger.debug(
                "cv_score is above threshold (%.2f >= %.2f); keeping",
                cv_score,
                absolute_threshold,
            )
            chosen = clf
        else:
            logger.debug(
                "cv_score is below threshold (%.2f < %.2f); discarding",
                cv_score,
                absolute_threshold,
            )
            # constant[1] is the prediction for column 1, the same target column
            # that was cross-validated above.
            chosen = ScikitLearnClassifier(
                DummyClassifier(
                    strategy="constant",
                    constant=self.constant[1],
                )
            )

        # Train chosen classifier on the full training set
        self.classifier = chosen
        chosen.fit(x_train, y_train)

    def predict_proba(self, x_test: np.ndarray) -> np.ndarray:
        """Return the predictions of the classifier chosen and trained by `fit`."""
        super().predict_proba(x_test)
        assert self.classifier is not None, "fit must be called before predict_proba"
        return self.classifier.predict_proba(x_test)

    def clone(self) -> "CrossValidatedClassifier":
        """Return a new, untrained CrossValidatedClassifier with the same settings."""
        return CrossValidatedClassifier(
            classifier=self.classifier_prototype,
            threshold=self.threshold,
            constant=self.constant,
            cv=self.cv,
            scoring=self.scoring,
        )
|