mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 09:58:51 -06:00
Refactor PrimalSolutionComponent
This commit is contained in:
@@ -142,8 +142,11 @@ class ScikitLearnClassifier(Classifier):
|
||||
|
||||
def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
|
||||
super().fit(x_train, y_train)
|
||||
(n_samples, n_classes) = x_train.shape
|
||||
assert n_classes == 2, "scikit-learn classifiers must have exactly two classes"
|
||||
(n_samples, n_classes) = y_train.shape
|
||||
assert n_classes == 2, (
|
||||
f"Scikit-learn classifiers must have exactly two classes. "
|
||||
f"{n_classes} classes were provided instead."
|
||||
)
|
||||
self.inner_clf.fit(x_train, y_train[:, 1])
|
||||
|
||||
def predict_proba(self, x_test: np.ndarray) -> np.ndarray:
|
||||
|
||||
Reference in New Issue
Block a user