mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Make AdaptiveClassifier pick best classifier based on score
This commit is contained in:
@@ -7,7 +7,9 @@ from copy import deepcopy
|
|||||||
|
|
||||||
from miplearn.classifiers import Classifier
|
from miplearn.classifiers import Classifier
|
||||||
from miplearn.classifiers.counting import CountingClassifier
|
from miplearn.classifiers.counting import CountingClassifier
|
||||||
|
from miplearn.classifiers.evaluator import ClassifierEvaluator
|
||||||
from sklearn.linear_model import LogisticRegression
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.neighbors import KNeighborsClassifier
|
||||||
from sklearn.pipeline import make_pipeline
|
from sklearn.pipeline import make_pipeline
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
@@ -17,45 +19,48 @@ logger = logging.getLogger(__name__)
|
|||||||
class AdaptiveClassifier(Classifier):
    """
    A meta-classifier which dynamically selects what actual classifier to use
    based on the score each candidate obtains on a particular training data set.
    """

    def __init__(self,
                 candidates=None,
                 evaluator=None):
        """
        Initializes the meta-classifier.

        The `candidates` argument must be a dictionary mapping an arbitrary
        name to another dictionary with two entries: "classifier", the
        prototype classifier (it is deep-copied before being fitted, so the
        prototype itself is never modified), and "min samples", the minimum
        number of training samples required for the candidate to be
        considered. For example:
        ```
        {
            "a": {"classifier": ClassifierA(), "min samples": 100},
            "b": {"classifier": ClassifierB(), "min samples": 0},
        }
        ```
        If `candidates` is None, a k-nearest neighbors classifier (k=100, at
        least 100 samples), a logistic regression with standard scaling (at
        least 30 samples) and a CountingClassifier (no minimum) are used.

        The `evaluator` argument must provide a method
        `evaluate(clf, x_train, y_train)` returning a score, where higher is
        better. If None, a new ClassifierEvaluator is created.
        """
        if candidates is None:
            candidates = {
                "knn(100)": {
                    "classifier": KNeighborsClassifier(n_neighbors=100),
                    "min samples": 100,
                },
                "logistic": {
                    "classifier": make_pipeline(StandardScaler(),
                                                LogisticRegression()),
                    "min samples": 30,
                },
                "counting": {
                    "classifier": CountingClassifier(),
                    "min samples": 0,
                }
            }
        # The default evaluator is created here rather than in the signature,
        # so that it is not instantiated at import time and shared by every
        # AdaptiveClassifier instance.
        if evaluator is None:
            evaluator = ClassifierEvaluator()
        self.candidates = candidates
        self.evaluator = evaluator
        self.classifier = None

    def fit(self, x_train, y_train):
        """
        Fits every eligible candidate and keeps the one with the highest score.

        A candidate is eligible if the number of rows of `x_train` is at least
        its "min samples" value. Ties are resolved in favor of the candidate
        that appears first in `candidates`. If no candidate is eligible,
        `self.classifier` is left as None and a warning is logged.
        """
        best_clf = None
        best_score = -float("inf")
        n_samples = x_train.shape[0]
        for clf_dict in self.candidates.values():
            if n_samples < clf_dict["min samples"]:
                continue
            # Work on a copy, so the prototype can be reused by later fits.
            clf = deepcopy(clf_dict["classifier"])
            clf.fit(x_train, y_train)
            score = self.evaluator.evaluate(clf, x_train, y_train)
            if score > best_score:
                best_clf, best_score = clf, score
        if best_clf is None:
            logger.warning("No candidate classifier could be selected "
                           "for %d samples", n_samples)
        self.classifier = best_clf

    def predict_proba(self, x_test):
        """
        Returns the probabilities predicted by the classifier selected in `fit`.
        """
        return self.classifier.predict_proba(x_test)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
15
src/python/miplearn/classifiers/evaluator.py
Normal file
15
src/python/miplearn/classifiers/evaluator.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
from sklearn.metrics import roc_auc_score
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierEvaluator:
    """
    Assigns a numerical score to an already-fitted classifier.
    """

    def __init__(self):
        """Creates the evaluator. No configuration is required."""

    def evaluate(self, clf, x_train, y_train):
        """
        Returns the area under the ROC curve obtained by `clf` on the given data.

        The score is computed from column 1 of `clf.predict_proba(x_train)`
        (presumably the probability of the positive class) against `y_train`.
        """
        # FIXME: use cross-validation
        positive_proba = clf.predict_proba(x_train)[:, 1]
        return roc_auc_score(y_train, positive_proba)
|
||||||
20
src/python/miplearn/classifiers/tests/test_evaluator.py
Normal file
20
src/python/miplearn/classifiers/tests/test_evaluator.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
||||||
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||||
|
# Released under the modified BSD license. See COPYING.md for more details.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from miplearn.classifiers.evaluator import ClassifierEvaluator
|
||||||
|
from sklearn.neighbors import KNeighborsClassifier
|
||||||
|
|
||||||
|
|
||||||
|
def test_evaluator():
    """
    Checks the scores given to 1-NN and 2-NN classifiers on a two-point data set.
    """
    x_train = np.array([[0, 0], [1, 0]])
    y_train = np.array([0, 1])
    evaluator = ClassifierEvaluator()
    # Maps the number of neighbors to the score the evaluator should return.
    expected_scores = {1: 1.0, 2: 0.5}
    for n_neighbors, expected in expected_scores.items():
        clf = KNeighborsClassifier(n_neighbors=n_neighbors)
        clf.fit(x_train, y_train)
        assert evaluator.evaluate(clf, x_train, y_train) == expected
|
||||||
|
|
||||||
Reference in New Issue
Block a user