mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 09:58:51 -06:00
Make classifiers and regressors clonable
This commit is contained in:
@@ -34,7 +34,7 @@ class CandidateClassifierSpecs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
classifier: Callable[[], Classifier],
|
||||
classifier: Classifier,
|
||||
min_samples: int = 0,
|
||||
) -> None:
|
||||
self.min_samples = min_samples
|
||||
@@ -64,13 +64,13 @@ class AdaptiveClassifier(Classifier):
|
||||
if candidates is None:
|
||||
candidates = {
|
||||
"knn(100)": CandidateClassifierSpecs(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
KNeighborsClassifier(n_neighbors=100)
|
||||
),
|
||||
min_samples=100,
|
||||
),
|
||||
"logistic": CandidateClassifierSpecs(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
make_pipeline(
|
||||
StandardScaler(),
|
||||
LogisticRegression(),
|
||||
@@ -79,7 +79,7 @@ class AdaptiveClassifier(Classifier):
|
||||
min_samples=30,
|
||||
),
|
||||
"counting": CandidateClassifierSpecs(
|
||||
classifier=lambda: CountingClassifier(),
|
||||
classifier=CountingClassifier(),
|
||||
),
|
||||
}
|
||||
self.candidates = candidates
|
||||
@@ -101,7 +101,7 @@ class AdaptiveClassifier(Classifier):
|
||||
for (name, specs) in self.candidates.items():
|
||||
if n_samples < specs.min_samples:
|
||||
continue
|
||||
clf = specs.classifier()
|
||||
clf = specs.classifier.clone()
|
||||
clf.fit(x_train, y_train)
|
||||
proba = clf.predict_proba(x_train)
|
||||
# FIXME: Switch to k-fold cross validation
|
||||
@@ -115,3 +115,6 @@ class AdaptiveClassifier(Classifier):
|
||||
super().predict_proba(x_test)
|
||||
assert self.classifier is not None
|
||||
return self.classifier.predict_proba(x_test)
|
||||
|
||||
def clone(self) -> "AdaptiveClassifier":
|
||||
return AdaptiveClassifier(self.candidates)
|
||||
|
||||
Reference in New Issue
Block a user