Make classifiers and regressors clonable

This commit is contained in:
2021-04-01 07:41:59 -05:00
parent ac29b5213f
commit 820a6256c2
7 changed files with 62 additions and 14 deletions

View File

@@ -34,7 +34,7 @@ class CandidateClassifierSpecs:
def __init__(
self,
classifier: Callable[[], Classifier],
classifier: Classifier,
min_samples: int = 0,
) -> None:
self.min_samples = min_samples
@@ -64,13 +64,13 @@ class AdaptiveClassifier(Classifier):
if candidates is None:
candidates = {
"knn(100)": CandidateClassifierSpecs(
classifier=lambda: ScikitLearnClassifier(
classifier=ScikitLearnClassifier(
KNeighborsClassifier(n_neighbors=100)
),
min_samples=100,
),
"logistic": CandidateClassifierSpecs(
classifier=lambda: ScikitLearnClassifier(
classifier=ScikitLearnClassifier(
make_pipeline(
StandardScaler(),
LogisticRegression(),
@@ -79,7 +79,7 @@ class AdaptiveClassifier(Classifier):
min_samples=30,
),
"counting": CandidateClassifierSpecs(
classifier=lambda: CountingClassifier(),
classifier=CountingClassifier(),
),
}
self.candidates = candidates
@@ -101,7 +101,7 @@ class AdaptiveClassifier(Classifier):
for (name, specs) in self.candidates.items():
if n_samples < specs.min_samples:
continue
clf = specs.classifier()
clf = specs.classifier.clone()
clf.fit(x_train, y_train)
proba = clf.predict_proba(x_train)
# FIXME: Switch to k-fold cross validation
@@ -115,3 +115,6 @@ class AdaptiveClassifier(Classifier):
super().predict_proba(x_test)
assert self.classifier is not None
return self.classifier.predict_proba(x_test)
def clone(self) -> "AdaptiveClassifier":
return AdaptiveClassifier(self.candidates)