Simplify AdaptiveClassifier

pull/3/head
Alinson S. Xavier 6 years ago
parent d13a548e80
commit c63a0777ed

````diff
@@ -5,106 +5,57 @@
 import logging
 from copy import deepcopy
-import numpy as np
 from miplearn.classifiers import Classifier
-from sklearn.model_selection import cross_val_score
+from miplearn.classifiers.counting import CountingClassifier
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
 
 logger = logging.getLogger(__name__)
 
 
 class AdaptiveClassifier(Classifier):
     """
-    A classifier that automatically switches strategies based on the number of
-    samples and cross-validation scores.
+    A meta-classifier which dynamically selects what actual classifier to use
+    based on the number of samples in the training data.
+
+    By default, uses CountingClassifier for less than 30 samples and
+    LogisticRegression (with standard scaling) for 30 or more samples.
     """
-    def __init__(self,
-                 predictor=None,
-                 min_samples_predict=1,
-                 min_samples_cv=100,
-                 thr_fix=0.999,
-                 thr_alpha=0.50,
-                 thr_balance=0.95,
-                 ):
-        self.min_samples_predict = min_samples_predict
-        self.min_samples_cv = min_samples_cv
-        self.thr_fix = thr_fix
-        self.thr_alpha = thr_alpha
-        self.thr_balance = thr_balance
-        self.predictor_factory = predictor
-        self.predictor = None
+
+    def __init__(self, classifiers=None):
+        """
+        Initializes the classifier.
+
+        The `classifiers` argument must be a list of tuples, where the first
+        element of each tuple is the minimum number of samples required and
+        the second element is the classifier to use. For example, if
+        `classifiers` is set to
+        ```
+        [(100, ClassifierA()),
+         (50, ClassifierB()),
+         (0, ClassifierC())]
+        ```
+        then ClassifierA will be used if n_samples >= 100, ClassifierB will
+        be used if 100 > n_samples >= 50, and ClassifierC will be used if
+        50 > n_samples. The list must be sorted in (strictly) decreasing
+        order of sample counts.
+        """
+        if classifiers is None:
+            classifiers = [
+                (30, make_pipeline(StandardScaler(), LogisticRegression())),
+                (0, CountingClassifier()),
+            ]
+        self.available_classifiers = classifiers
+        self.classifier = None
 
     def fit(self, x_train, y_train):
         n_samples = x_train.shape[0]
-        # If number of samples is too small, don't predict anything.
-        if n_samples < self.min_samples_predict:
-            logger.debug("Too few samples (%d); always predicting false" % n_samples)
-            self.predictor = 0
-            return
-        # If the vast majority of observations are false, always return false.
-        y_train_avg = np.average(y_train)
-        if y_train_avg <= 1.0 - self.thr_fix:
-            logger.debug("Most samples are negative (%.3f); always returning false" % y_train_avg)
-            self.predictor = 0
-            return
-        # If the vast majority of observations are true, always return true.
-        if y_train_avg >= self.thr_fix:
-            logger.debug("Most samples are positive (%.3f); always returning true" % y_train_avg)
-            self.predictor = 1
-            return
-        # If classes are too unbalanced, don't predict anything.
-        if y_train_avg < (1 - self.thr_balance) or y_train_avg > self.thr_balance:
-            logger.debug("Classes are too unbalanced (%.3f); always returning false" % y_train_avg)
-            self.predictor = 0
-            return
-        # Select an ML model if none is provided.
-        if self.predictor_factory is None:
-            if n_samples < 30:
-                from sklearn.neighbors import KNeighborsClassifier
-                self.predictor_factory = KNeighborsClassifier(n_neighbors=n_samples)
-            else:
-                from sklearn.pipeline import make_pipeline
-                from sklearn.preprocessing import StandardScaler
-                from sklearn.linear_model import LogisticRegression
-                self.predictor_factory = make_pipeline(StandardScaler(), LogisticRegression())
-        # Create the predictor.
-        if callable(self.predictor_factory):
-            pred = self.predictor_factory()
-        else:
-            pred = deepcopy(self.predictor_factory)
-        # Skip cross-validation if the number of samples is too small.
-        if n_samples < self.min_samples_cv:
-            logger.debug("Too few samples (%d); skipping cross validation" % n_samples)
-            self.predictor = pred
-            self.predictor.fit(x_train, y_train)
-            return
-        # Calculate the cross-validation score.
-        cv_score = np.mean(cross_val_score(pred, x_train, y_train, cv=5))
-        dummy_score = max(y_train_avg, 1 - y_train_avg)
-        cv_thr = 1. * self.thr_alpha + dummy_score * (1 - self.thr_alpha)
-        # If the cross-validation score is too low, don't predict anything.
-        if cv_score < cv_thr:
-            logger.debug("Score is too low (%.3f < %.3f); always returning false" % (cv_score, cv_thr))
-            self.predictor = 0
-        else:
-            logger.debug("Score is acceptable (%.3f > %.3f); training classifier" % (cv_score, cv_thr))
-            self.predictor = pred
-            self.predictor.fit(x_train, y_train)
+        for (min_samples, clf_prototype) in self.available_classifiers:
+            if n_samples >= min_samples:
+                self.classifier = deepcopy(clf_prototype)
+                self.classifier.fit(x_train, y_train)
+                break
 
     def predict_proba(self, x_test):
-        if isinstance(self.predictor, int):
-            y_pred = np.zeros((x_test.shape[0], 2))
-            y_pred[:, self.predictor] = 1.0
-            return y_pred
-        else:
-            return self.predictor.predict_proba(x_test)
+        return self.classifier.predict_proba(x_test)
````
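The new docstring spells out the selection rule; here is a minimal usage sketch of the simplified class. The `miplearn.classifiers.adaptive` module path is assumed, and the custom ladder at the end is a hypothetical configuration, not one shipped by this commit:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from miplearn.classifiers.adaptive import AdaptiveClassifier  # assumed module path
from miplearn.classifiers.counting import CountingClassifier

# Default ladder: fewer than 30 training samples -> CountingClassifier;
# 30 or more -> the StandardScaler + LogisticRegression pipeline.
x_train = np.random.rand(10, 5)
y_train = np.random.randint(0, 2, size=10)
clf = AdaptiveClassifier()
clf.fit(x_train, y_train)
print(clf.classifier)  # CountingClassifier(mean=...), since 10 < 30

# A hypothetical custom ladder, in strictly decreasing order of
# required sample counts, as the docstring requires:
clf = AdaptiveClassifier(classifiers=[
    (100, KNeighborsClassifier(n_neighbors=5)),
    (0, CountingClassifier()),
])
clf.fit(x_train, y_train)  # 10 < 100, so CountingClassifier is chosen
```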

```diff
@@ -21,7 +21,8 @@ class CountingClassifier(Classifier):
         self.mean = np.mean(y_train)
 
     def predict_proba(self, x_test):
-        return np.array([[1 - self.mean, self.mean]])
+        return np.array([[1 - self.mean, self.mean]
+                         for _ in range(x_test.shape[0])])
 
     def __repr__(self):
-        return "CountingClassifier(mean=%.3f)" % self.mean
+        return "CountingClassifier(mean=%s)" % self.mean
```

```diff
@@ -12,6 +12,7 @@ E = 0.1
 def test_counting():
     clf = CountingClassifier()
     clf.fit(np.zeros((8, 25)), [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0])
-    expected_proba = np.array([[0.375, 0.625]])
-    actual_proba = clf.predict_proba(np.zeros((1, 25)))
+    expected_proba = np.array([[0.375, 0.625],
+                               [0.375, 0.625]])
+    actual_proba = clf.predict_proba(np.zeros((2, 25)))
     assert norm(actual_proba - expected_proba) < E
```

```diff
@@ -67,6 +67,7 @@ class PrimalSolutionComponent(Component):
         x_test = VariableFeaturesExtractor().extract([instance])
         var_split = Extractor.split_variables(instance)
         for category in var_split.keys():
+            n = len(var_split[category])
             for (i, (var, index)) in enumerate(var_split[category]):
                 if var not in solution.keys():
                     solution[var] = {}
@@ -76,10 +77,12 @@ class PrimalSolutionComponent(Component):
                     continue
                 clf = self.classifiers[category, label]
                 if isinstance(clf, float):
-                    ws = np.array([[1-clf, clf]
-                                   for _ in range(len(var_split[category]))])
+                    ws = np.array([[1 - clf, clf] for _ in range(n)])
                 else:
                     ws = clf.predict_proba(x_test[category])
+                print("clf=", clf)
+                print("x_test=", x_test[category])
+                assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (n, ws.shape)
                 for (i, (var, index)) in enumerate(var_split[category]):
                     if ws[i, 1] >= self.threshold:
                         solution[var][index] = label
```
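The new assertion pins down the contract the warm-start loop relies on: for a category with `n` variables, `ws` must be an `(n, 2)` array whose second column gives the probability of the label, and only entries at or above `self.threshold` are written into the solution. A standalone sketch of that gating step, with hypothetical names and values:

```python
import numpy as np

threshold = 0.9  # hypothetical confidence cutoff

# One [P(not label), P(label)] row per variable in the category:
ws = np.array([[0.05, 0.95],
               [0.50, 0.50],
               [0.02, 0.98]])
n = ws.shape[0]
assert ws.shape == (n, 2)

# Only confident predictions make it into the warm start:
fixed = [i for i in range(n) if ws[i, 1] >= threshold]
print(fixed)  # [0, 2]
```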
