mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Make classifiers and regressors clonable
This commit is contained in:
@@ -49,6 +49,13 @@ class Threshold(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clone(self) -> "Threshold":
|
||||
"""
|
||||
Returns an unfitted copy of this threshold with the same hyperparameters.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MinProbabilityThreshold(Threshold):
|
||||
"""
|
||||
@@ -65,6 +72,9 @@ class MinProbabilityThreshold(Threshold):
|
||||
def predict(self, x_test: np.ndarray) -> List[float]:
|
||||
return self.min_probability
|
||||
|
||||
def clone(self) -> "MinProbabilityThreshold":
|
||||
return MinProbabilityThreshold(self.min_probability)
|
||||
|
||||
|
||||
class MinPrecisionThreshold(Threshold):
|
||||
"""
|
||||
@@ -111,3 +121,8 @@ class MinPrecisionThreshold(Threshold):
|
||||
if precision[k] >= min_precision:
|
||||
return thresholds[k]
|
||||
return float("inf")
|
||||
|
||||
def clone(self) -> "MinPrecisionThreshold":
|
||||
return MinPrecisionThreshold(
|
||||
min_precision=self.min_precision,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user