Refactor PrimalSolutionComponent

This commit is contained in:
2021-01-25 14:54:58 -06:00
parent f68cc5bd59
commit 3ab3bb3c1f
9 changed files with 501 additions and 233 deletions

View File

@@ -142,8 +142,11 @@ class ScikitLearnClassifier(Classifier):
def fit(self, x_train: np.ndarray, y_train: np.ndarray) -> None:
super().fit(x_train, y_train)
(n_samples, n_classes) = x_train.shape
assert n_classes == 2, "scikit-learn classifiers must have exactly two classes"
(n_samples, n_classes) = y_train.shape
assert n_classes == 2, (
f"Scikit-learn classifiers must have exactly two classes. "
f"{n_classes} classes were provided instead."
)
self.inner_clf.fit(x_train, y_train[:, 1])
def predict_proba(self, x_test: np.ndarray) -> np.ndarray: