From 820a6256c2ac1d531527a1545ca2eb2af76fc3e4 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Thu, 1 Apr 2021 07:41:59 -0500 Subject: [PATCH] Make classifiers and regressors clonable --- miplearn/classifiers/__init__.py | 20 ++++++++++++++++++++ miplearn/classifiers/adaptive.py | 13 ++++++++----- miplearn/classifiers/counting.py | 3 +++ miplearn/classifiers/cv.py | 17 ++++++++++++----- miplearn/classifiers/threshold.py | 15 +++++++++++++++ tests/classifiers/test_adaptive.py | 4 ++-- tests/classifiers/test_cv.py | 4 ++-- 7 files changed, 62 insertions(+), 14 deletions(-) diff --git a/miplearn/classifiers/__init__.py b/miplearn/classifiers/__init__.py index cfa8f64..b3172b7 100644 --- a/miplearn/classifiers/__init__.py +++ b/miplearn/classifiers/__init__.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import Optional, Any, cast import numpy as np +import sklearn class Classifier(ABC): @@ -77,6 +78,13 @@ class Classifier(ABC): ) return np.ndarray([]) + @abstractmethod + def clone(self) -> "Classifier": + """ + Returns an unfitted copy of this classifier with the same hyperparameters. + """ + pass + class Regressor(ABC): """ @@ -136,6 +144,13 @@ class Regressor(ABC): assert n_inputs_x == self.n_inputs return np.ndarray([]) + @abstractmethod + def clone(self) -> "Regressor": + """ + Returns an unfitted copy of this regressor with the same hyperparameters. + """ + pass + class ScikitLearnClassifier(Classifier): """ @@ -185,3 +200,8 @@ class ScikitLearnClassifier(Classifier): assert isinstance(sklearn_proba, np.ndarray) assert sklearn_proba.shape == (n_samples, 2) return sklearn_proba + + def clone(self) -> "ScikitLearnClassifier": + return ScikitLearnClassifier( + clf=sklearn.base.clone(self.inner_clf), + ) diff --git a/miplearn/classifiers/adaptive.py b/miplearn/classifiers/adaptive.py index 94d753d..d12554e 100644 --- a/miplearn/classifiers/adaptive.py +++ b/miplearn/classifiers/adaptive.py @@ -34,7 +34,7 @@ class CandidateClassifierSpecs: def __init__( self, - classifier: Callable[[], Classifier], + classifier: Classifier, min_samples: int = 0, ) -> None: self.min_samples = min_samples @@ -64,13 +64,13 @@ class AdaptiveClassifier(Classifier): if candidates is None: candidates = { "knn(100)": CandidateClassifierSpecs( - classifier=lambda: ScikitLearnClassifier( + classifier=ScikitLearnClassifier( KNeighborsClassifier(n_neighbors=100) ), min_samples=100, ), "logistic": CandidateClassifierSpecs( - classifier=lambda: ScikitLearnClassifier( + classifier=ScikitLearnClassifier( make_pipeline( StandardScaler(), LogisticRegression(), @@ -79,7 +79,7 @@ class AdaptiveClassifier(Classifier): min_samples=30, ), "counting": CandidateClassifierSpecs( - classifier=lambda: CountingClassifier(), + classifier=CountingClassifier(), ), } self.candidates = candidates @@ -101,7 +101,7 @@ class AdaptiveClassifier(Classifier): for (name, specs) in self.candidates.items(): if n_samples < specs.min_samples: continue - clf = specs.classifier() + clf = specs.classifier.clone() clf.fit(x_train, y_train) proba = clf.predict_proba(x_train) # FIXME: Switch to k-fold cross validation @@ -115,3 +115,6 @@ class AdaptiveClassifier(Classifier): super().predict_proba(x_test) assert self.classifier is not None return self.classifier.predict_proba(x_test) + + def clone(self) -> "AdaptiveClassifier": + return AdaptiveClassifier(self.candidates) diff --git a/miplearn/classifiers/counting.py b/miplearn/classifiers/counting.py index 226013b..c793245 100644 --- a/miplearn/classifiers/counting.py +++ b/miplearn/classifiers/counting.py @@ -40,3 +40,6 @@ class CountingClassifier(Classifier): def __repr__(self): return "CountingClassifier(mean=%s)" % self.mean + + def clone(self) -> "CountingClassifier": + return CountingClassifier() diff --git a/miplearn/classifiers/cv.py b/miplearn/classifiers/cv.py index 2743457..79cb822 100644 --- a/miplearn/classifiers/cv.py +++ b/miplearn/classifiers/cv.py @@ -46,9 +46,7 @@ class CrossValidatedClassifier(Classifier): def __init__( self, - classifier: Callable[[], ScikitLearnClassifier] = ( - lambda: ScikitLearnClassifier(LogisticRegression()) - ), + classifier: ScikitLearnClassifier = ScikitLearnClassifier(LogisticRegression()), threshold: float = 0.75, constant: Optional[List[bool]] = None, cv: int = 5, @@ -60,7 +58,7 @@ class CrossValidatedClassifier(Classifier): constant = [True, False] self.n_classes = len(constant) self.classifier: Optional[ScikitLearnClassifier] = None - self.classifier_factory = classifier + self.classifier_prototype = classifier self.constant: List[bool] = constant self.threshold = threshold self.cv = cv @@ -77,7 +75,7 @@ class CrossValidatedClassifier(Classifier): absolute_threshold = 1.0 * self.threshold + dummy_score * (1 - self.threshold) # Calculate cross validation score and decide which classifier to use - clf = self.classifier_factory() + clf = self.classifier_prototype.clone() assert clf is not None assert isinstance(clf, ScikitLearnClassifier), ( f"The provided classifier callable must return a ScikitLearnClassifier. " @@ -123,3 +121,12 @@ class CrossValidatedClassifier(Classifier): super().predict_proba(x_test) assert self.classifier is not None return self.classifier.predict_proba(x_test) + + def clone(self) -> "CrossValidatedClassifier": + return CrossValidatedClassifier( + classifier=self.classifier_prototype, + threshold=self.threshold, + constant=self.constant, + cv=self.cv, + scoring=self.scoring, + ) diff --git a/miplearn/classifiers/threshold.py b/miplearn/classifiers/threshold.py index aa0fa06..9df6a4d 100644 --- a/miplearn/classifiers/threshold.py +++ b/miplearn/classifiers/threshold.py @@ -49,6 +49,13 @@ class Threshold(ABC): """ pass + @abstractmethod + def clone(self) -> "Threshold": + """ + Returns an unfitted copy of this threshold with the same hyperparameters. + """ + pass + class MinProbabilityThreshold(Threshold): """ @@ -65,6 +72,9 @@ class MinProbabilityThreshold(Threshold): def predict(self, x_test: np.ndarray) -> List[float]: return self.min_probability + def clone(self) -> "MinProbabilityThreshold": + return MinProbabilityThreshold(self.min_probability) + class MinPrecisionThreshold(Threshold): """ @@ -111,3 +121,8 @@ class MinPrecisionThreshold(Threshold): if precision[k] >= min_precision: return thresholds[k] return float("inf") + + def clone(self) -> "MinPrecisionThreshold": + return MinPrecisionThreshold( + min_precision=self.min_precision, + ) diff --git a/tests/classifiers/test_adaptive.py b/tests/classifiers/test_adaptive.py index 959b196..7ac411f 100644 --- a/tests/classifiers/test_adaptive.py +++ b/tests/classifiers/test_adaptive.py @@ -15,7 +15,7 @@ def test_adaptive() -> None: clf = AdaptiveClassifier( candidates={ "linear": CandidateClassifierSpecs( - classifier=lambda: ScikitLearnClassifier( + classifier=ScikitLearnClassifier( SVC( probability=True, random_state=42, @@ -23,7 +23,7 @@ def test_adaptive() -> None: ) ), "poly": CandidateClassifierSpecs( - classifier=lambda: ScikitLearnClassifier( + classifier=ScikitLearnClassifier( SVC( probability=True, kernel="poly", diff --git a/tests/classifiers/test_cv.py b/tests/classifiers/test_cv.py index 618b3cf..b15f138 100644 --- a/tests/classifiers/test_cv.py +++ b/tests/classifiers/test_cv.py @@ -20,7 +20,7 @@ def test_cv() -> None: # Support vector machines with linear kernels do not perform well on this # data set, so predictor should return the given constant. clf = CrossValidatedClassifier( - classifier=lambda: ScikitLearnClassifier( + classifier=ScikitLearnClassifier( SVC( probability=True, random_state=42, @@ -41,7 +41,7 @@ def test_cv() -> None: # Support vector machines with quadratic kernels perform almost perfectly # on this data set, so predictor should return their prediction. clf = CrossValidatedClassifier( - classifier=lambda: ScikitLearnClassifier( + classifier=ScikitLearnClassifier( SVC( probability=True, kernel="poly",