Improve stable set generator

This commit is contained in:
2020-01-26 08:25:26 -06:00
parent 3644c59101
commit 3c9b1e2f44
6 changed files with 135 additions and 54 deletions

View File

@@ -18,6 +18,7 @@ class WarmStartPredictor(ABC):
def fit(self, x_train, y_train):
assert isinstance(x_train, np.ndarray)
assert isinstance(y_train, np.ndarray)
y_train = y_train.astype(int)
assert y_train.shape[0] == x_train.shape[0]
assert y_train.shape[1] == 2
for i in [0,1]: