Make classifiers and regressors clonable

This commit is contained in:
2021-04-01 07:41:59 -05:00
parent ac29b5213f
commit 820a6256c2
7 changed files with 62 additions and 14 deletions

View File

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from typing import Optional, Any, cast
import numpy as np
import sklearn
class Classifier(ABC):
@@ -77,6 +78,13 @@ class Classifier(ABC):
)
return np.ndarray([])
@abstractmethod
def clone(self) -> "Classifier":
"""
Returns an unfitted copy of this classifier with the same hyperparameters.
"""
pass
class Regressor(ABC):
"""
@@ -136,6 +144,13 @@ class Regressor(ABC):
assert n_inputs_x == self.n_inputs
return np.ndarray([])
@abstractmethod
def clone(self) -> "Regressor":
"""
Returns an unfitted copy of this regressor with the same hyperparameters.
"""
pass
class ScikitLearnClassifier(Classifier):
"""
@@ -185,3 +200,8 @@ class ScikitLearnClassifier(Classifier):
assert isinstance(sklearn_proba, np.ndarray)
assert sklearn_proba.shape == (n_samples, 2)
return sklearn_proba
def clone(self) -> "ScikitLearnClassifier":
return ScikitLearnClassifier(
clf=sklearn.base.clone(self.inner_clf),
)