Make classifiers and regressors clonable

This commit is contained in:
2021-04-01 07:41:59 -05:00
parent ac29b5213f
commit 820a6256c2
7 changed files with 62 additions and 14 deletions

View File

@@ -49,6 +49,13 @@ class Threshold(ABC):
"""
pass
@abstractmethod
def clone(self) -> "Threshold":
"""
Returns an unfitted copy of this threshold with the same hyperparameters.
"""
pass
class MinProbabilityThreshold(Threshold):
"""
@@ -65,6 +72,9 @@ class MinProbabilityThreshold(Threshold):
def predict(self, x_test: np.ndarray) -> List[float]:
return self.min_probability
def clone(self) -> "MinProbabilityThreshold":
return MinProbabilityThreshold(self.min_probability)
class MinPrecisionThreshold(Threshold):
"""
@@ -111,3 +121,8 @@ class MinPrecisionThreshold(Threshold):
if precision[k] >= min_precision:
return thresholds[k]
return float("inf")
def clone(self) -> "MinPrecisionThreshold":
return MinPrecisionThreshold(
min_precision=self.min_precision,
)