mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Start refactoring of classifiers
This commit is contained in:
@@ -152,25 +152,18 @@ dtype: float64
|
||||
|
||||
### Using customized ML classifiers and regressors
|
||||
|
||||
By default, given a training set of instantes, MIPLearn trains a fixed set of ML classifiers and regressors, then
|
||||
selects the best one based on cross-validation performance. Alternatively, the user may specify which ML model a component
|
||||
should use through the `classifier` or `regressor` contructor parameters. The provided classifiers and regressors must
|
||||
follow the sklearn API. In particular, classifiers must provide the methods `fit`, `predict_proba` and `predict`,
|
||||
while regressors must provide the methods `fit` and `predict`
|
||||
By default, given a training set of instantes, MIPLearn trains a fixed set of ML classifiers and regressors, then selects the best one based on cross-validation performance. Alternatively, the user may specify which ML model a component should use through the `classifier` or `regressor` contructor parameters. Scikit-learn classifiers and regressors are currently supported. A future version of the package will add compatibility with Keras models.
|
||||
|
||||
!!! danger
|
||||
MIPLearn must be able to generate a copy of any custom ML classifiers and regressors through
|
||||
the standard `copy.deepcopy` method. This currently makes it incompatible with Keras and TensorFlow
|
||||
predictors. This is a known limitation, which will be addressed in a future version.
|
||||
|
||||
The example below shows how to construct a `PrimalSolutionComponent` which internally uses
|
||||
sklearn's `KNeighborsClassifiers`. Any other sklearn classifier or pipeline can be used.
|
||||
The example below shows how to construct a `PrimalSolutionComponent` which internally uses scikit-learn's `KNeighborsClassifiers`. Any other scikit-learn classifier or pipeline can be used. The classifier needs to be provided as a lambda function because the component may need to create multiple copies of it. It needs to be wrapped in `ScikitLearnClassifier` to ensure that all the proper data transformations are applied.
|
||||
|
||||
```python
|
||||
from miplearn import PrimalSolutionComponent
|
||||
from miplearn import PrimalSolutionComponent, ScikitLearnClassifier
|
||||
from sklearn.neighbors import KNeighborsClassifier
|
||||
|
||||
comp = PrimalSolutionComponent(classifier=KNeighborsClassifier(n_neighbors=5))
|
||||
comp = PrimalSolutionComponent(
|
||||
classifier=lambda: ScikitLearnClassifier(
|
||||
KNeighborsClassifier(n_neighbors=5),
|
||||
),
|
||||
)
|
||||
comp.fit(train_instances)
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user