Refactor thresholds

This commit is contained in:
2021-01-25 09:52:49 -06:00
parent 4da561a6a8
commit f68cc5bd59
4 changed files with 82 additions and 41 deletions

View File

@@ -26,13 +26,17 @@ def test_threshold_dynamic():
y_train = np.array([1, 1, 0, 0])
threshold = MinPrecisionThreshold(min_precision=1.0)
assert threshold.find(clf, x_train, y_train) == 0.90
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.90
threshold = MinPrecisionThreshold(min_precision=0.65)
assert threshold.find(clf, x_train, y_train) == 0.80
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.80
threshold = MinPrecisionThreshold(min_precision=0.50)
assert threshold.find(clf, x_train, y_train) == 0.70
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.70
threshold = MinPrecisionThreshold(min_precision=0.00)
assert threshold.find(clf, x_train, y_train) == 0.70
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.70