mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 09:58:51 -06:00
Make classifiers and regressors clonable
This commit is contained in:
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional, Any, cast
|
||||
|
||||
import numpy as np
|
||||
import sklearn
|
||||
|
||||
|
||||
class Classifier(ABC):
|
||||
@@ -77,6 +78,13 @@ class Classifier(ABC):
|
||||
)
|
||||
return np.ndarray([])
|
||||
|
||||
@abstractmethod
|
||||
def clone(self) -> "Classifier":
|
||||
"""
|
||||
Returns an unfitted copy of this classifier with the same hyperparameters.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Regressor(ABC):
|
||||
"""
|
||||
@@ -136,6 +144,13 @@ class Regressor(ABC):
|
||||
assert n_inputs_x == self.n_inputs
|
||||
return np.ndarray([])
|
||||
|
||||
@abstractmethod
|
||||
def clone(self) -> "Regressor":
|
||||
"""
|
||||
Returns an unfitted copy of this regressor with the same hyperparameters.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ScikitLearnClassifier(Classifier):
|
||||
"""
|
||||
@@ -185,3 +200,8 @@ class ScikitLearnClassifier(Classifier):
|
||||
assert isinstance(sklearn_proba, np.ndarray)
|
||||
assert sklearn_proba.shape == (n_samples, 2)
|
||||
return sklearn_proba
|
||||
|
||||
def clone(self) -> "ScikitLearnClassifier":
|
||||
return ScikitLearnClassifier(
|
||||
clf=sklearn.base.clone(self.inner_clf),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user