|
|
|
@ -80,8 +80,6 @@ class PrimalSolutionComponent(Component):
|
|
|
|
|
ws = np.array([[1 - clf, clf] for _ in range(n)])
|
|
|
|
|
else:
|
|
|
|
|
ws = clf.predict_proba(x_test[category])
|
|
|
|
|
print("clf=", clf)
|
|
|
|
|
print("x_test=", x_test[category])
|
|
|
|
|
assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (n, ws.shape)
|
|
|
|
|
for (i, (var, index)) in enumerate(var_split[category]):
|
|
|
|
|
if ws[i, 1] >= self.threshold:
|
|
|
|
|