mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Make classifiers and regressors clonable
This commit is contained in:
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional, Any, cast
|
||||
|
||||
import numpy as np
|
||||
import sklearn
|
||||
|
||||
|
||||
class Classifier(ABC):
|
||||
@@ -77,6 +78,13 @@ class Classifier(ABC):
|
||||
)
|
||||
return np.ndarray([])
|
||||
|
||||
@abstractmethod
def clone(self) -> "Classifier":
    """
    Create a fresh, unfitted classifier that carries the same
    hyperparameters as this one.

    This is an abstract hook: the base implementation does nothing and
    concrete classifiers are expected to override it.
    """
|
||||
|
||||
|
||||
class Regressor(ABC):
|
||||
"""
|
||||
@@ -136,6 +144,13 @@ class Regressor(ABC):
|
||||
assert n_inputs_x == self.n_inputs
|
||||
return np.ndarray([])
|
||||
|
||||
@abstractmethod
def clone(self) -> "Regressor":
    """
    Create a fresh, unfitted regressor that carries the same
    hyperparameters as this one.

    This is an abstract hook: the base implementation does nothing and
    concrete regressors are expected to override it.
    """
|
||||
|
||||
|
||||
class ScikitLearnClassifier(Classifier):
|
||||
"""
|
||||
@@ -185,3 +200,8 @@ class ScikitLearnClassifier(Classifier):
|
||||
assert isinstance(sklearn_proba, np.ndarray)
|
||||
assert sklearn_proba.shape == (n_samples, 2)
|
||||
return sklearn_proba
|
||||
|
||||
def clone(self) -> "ScikitLearnClassifier":
    """
    Build a new ScikitLearnClassifier around a copy of the wrapped estimator.

    The wrapped estimator (``self.inner_clf``) is duplicated with
    ``sklearn.base.clone`` and the copy is passed to the constructor as
    ``clf``.
    """
    fresh_inner = sklearn.base.clone(self.inner_clf)
    return ScikitLearnClassifier(clf=fresh_inner)
|
||||
|
||||
@@ -34,7 +34,7 @@ class CandidateClassifierSpecs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
classifier: Callable[[], Classifier],
|
||||
classifier: Classifier,
|
||||
min_samples: int = 0,
|
||||
) -> None:
|
||||
self.min_samples = min_samples
|
||||
@@ -64,13 +64,13 @@ class AdaptiveClassifier(Classifier):
|
||||
if candidates is None:
|
||||
candidates = {
|
||||
"knn(100)": CandidateClassifierSpecs(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
KNeighborsClassifier(n_neighbors=100)
|
||||
),
|
||||
min_samples=100,
|
||||
),
|
||||
"logistic": CandidateClassifierSpecs(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
make_pipeline(
|
||||
StandardScaler(),
|
||||
LogisticRegression(),
|
||||
@@ -79,7 +79,7 @@ class AdaptiveClassifier(Classifier):
|
||||
min_samples=30,
|
||||
),
|
||||
"counting": CandidateClassifierSpecs(
|
||||
classifier=lambda: CountingClassifier(),
|
||||
classifier=CountingClassifier(),
|
||||
),
|
||||
}
|
||||
self.candidates = candidates
|
||||
@@ -101,7 +101,7 @@ class AdaptiveClassifier(Classifier):
|
||||
for (name, specs) in self.candidates.items():
|
||||
if n_samples < specs.min_samples:
|
||||
continue
|
||||
clf = specs.classifier()
|
||||
clf = specs.classifier.clone()
|
||||
clf.fit(x_train, y_train)
|
||||
proba = clf.predict_proba(x_train)
|
||||
# FIXME: Switch to k-fold cross validation
|
||||
@@ -115,3 +115,6 @@ class AdaptiveClassifier(Classifier):
|
||||
super().predict_proba(x_test)
|
||||
assert self.classifier is not None
|
||||
return self.classifier.predict_proba(x_test)
|
||||
|
||||
def clone(self) -> "AdaptiveClassifier":
    """
    Build a new AdaptiveClassifier from this instance's candidate specs.

    The ``candidates`` object is handed to the constructor as-is; it is
    not copied here.
    """
    candidate_specs = self.candidates
    return AdaptiveClassifier(candidate_specs)
|
||||
|
||||
@@ -40,3 +40,6 @@ class CountingClassifier(Classifier):
|
||||
|
||||
def __repr__(self):
    """
    Return a debugging string of the form ``CountingClassifier(mean=<mean>)``.

    An f-string is used instead of ``"%s" % self.mean``: with the bare
    ``%`` operator a tuple-valued ``mean`` would be unpacked as the
    argument list and either raise ``TypeError`` or be misformatted.
    """
    return f"CountingClassifier(mean={self.mean})"
|
||||
|
||||
def clone(self) -> "CountingClassifier":
    """
    Return a brand-new CountingClassifier constructed with no arguments.
    """
    fresh = CountingClassifier()
    return fresh
|
||||
|
||||
@@ -46,9 +46,7 @@ class CrossValidatedClassifier(Classifier):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
classifier: Callable[[], ScikitLearnClassifier] = (
|
||||
lambda: ScikitLearnClassifier(LogisticRegression())
|
||||
),
|
||||
classifier: ScikitLearnClassifier = ScikitLearnClassifier(LogisticRegression()),
|
||||
threshold: float = 0.75,
|
||||
constant: Optional[List[bool]] = None,
|
||||
cv: int = 5,
|
||||
@@ -60,7 +58,7 @@ class CrossValidatedClassifier(Classifier):
|
||||
constant = [True, False]
|
||||
self.n_classes = len(constant)
|
||||
self.classifier: Optional[ScikitLearnClassifier] = None
|
||||
self.classifier_factory = classifier
|
||||
self.classifier_prototype = classifier
|
||||
self.constant: List[bool] = constant
|
||||
self.threshold = threshold
|
||||
self.cv = cv
|
||||
@@ -77,7 +75,7 @@ class CrossValidatedClassifier(Classifier):
|
||||
absolute_threshold = 1.0 * self.threshold + dummy_score * (1 - self.threshold)
|
||||
|
||||
# Calculate cross validation score and decide which classifier to use
|
||||
clf = self.classifier_factory()
|
||||
clf = self.classifier_prototype.clone()
|
||||
assert clf is not None
|
||||
assert isinstance(clf, ScikitLearnClassifier), (
|
||||
f"The provided classifier callable must return a ScikitLearnClassifier. "
|
||||
@@ -123,3 +121,12 @@ class CrossValidatedClassifier(Classifier):
|
||||
super().predict_proba(x_test)
|
||||
assert self.classifier is not None
|
||||
return self.classifier.predict_proba(x_test)
|
||||
|
||||
def clone(self) -> "CrossValidatedClassifier":
    """
    Build a new CrossValidatedClassifier configured like this one.

    The prototype classifier, threshold, constant, cv and scoring values
    stored on this instance are forwarded unchanged to the constructor.
    """
    settings = {
        "classifier": self.classifier_prototype,
        "threshold": self.threshold,
        "constant": self.constant,
        "cv": self.cv,
        "scoring": self.scoring,
    }
    return CrossValidatedClassifier(**settings)
|
||||
|
||||
@@ -49,6 +49,13 @@ class Threshold(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
def clone(self) -> "Threshold":
    """
    Create a fresh, unfitted threshold that carries the same
    hyperparameters as this one.

    This is an abstract hook: the base implementation does nothing and
    concrete thresholds are expected to override it.
    """
|
||||
|
||||
|
||||
class MinProbabilityThreshold(Threshold):
|
||||
"""
|
||||
@@ -65,6 +72,9 @@ class MinProbabilityThreshold(Threshold):
|
||||
def predict(self, x_test: np.ndarray) -> List[float]:
    """
    Return the stored ``min_probability`` value.

    ``x_test`` is accepted but not consulted; the same stored object is
    returned on every call.
    """
    thresholds = self.min_probability
    return thresholds
|
||||
|
||||
def clone(self) -> "MinProbabilityThreshold":
    """
    Build a new MinProbabilityThreshold from this instance's
    ``min_probability``.

    The value is handed to the constructor as-is; it is not copied here.
    """
    current = self.min_probability
    return MinProbabilityThreshold(current)
|
||||
|
||||
|
||||
class MinPrecisionThreshold(Threshold):
|
||||
"""
|
||||
@@ -111,3 +121,8 @@ class MinPrecisionThreshold(Threshold):
|
||||
if precision[k] >= min_precision:
|
||||
return thresholds[k]
|
||||
return float("inf")
|
||||
|
||||
def clone(self) -> "MinPrecisionThreshold":
    """
    Build a new MinPrecisionThreshold that uses this instance's
    ``min_precision`` setting.
    """
    precision_floor = self.min_precision
    return MinPrecisionThreshold(min_precision=precision_floor)
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_adaptive() -> None:
|
||||
clf = AdaptiveClassifier(
|
||||
candidates={
|
||||
"linear": CandidateClassifierSpecs(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
SVC(
|
||||
probability=True,
|
||||
random_state=42,
|
||||
@@ -23,7 +23,7 @@ def test_adaptive() -> None:
|
||||
)
|
||||
),
|
||||
"poly": CandidateClassifierSpecs(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
SVC(
|
||||
probability=True,
|
||||
kernel="poly",
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_cv() -> None:
|
||||
# Support vector machines with linear kernels do not perform well on this
|
||||
# data set, so predictor should return the given constant.
|
||||
clf = CrossValidatedClassifier(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
SVC(
|
||||
probability=True,
|
||||
random_state=42,
|
||||
@@ -41,7 +41,7 @@ def test_cv() -> None:
|
||||
# Support vector machines with quadratic kernels perform almost perfectly
|
||||
# on this data set, so predictor should return their prediction.
|
||||
clf = CrossValidatedClassifier(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
classifier=ScikitLearnClassifier(
|
||||
SVC(
|
||||
probability=True,
|
||||
kernel="poly",
|
||||
|
||||
Reference in New Issue
Block a user