mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Refactor thresholds
This commit is contained in:
@@ -26,13 +26,17 @@ def test_threshold_dynamic():
|
||||
y_train = np.array([1, 1, 0, 0])
|
||||
|
||||
threshold = MinPrecisionThreshold(min_precision=1.0)
|
||||
assert threshold.find(clf, x_train, y_train) == 0.90
|
||||
threshold.fit(clf, x_train, y_train)
|
||||
assert threshold.predict(x_train) == 0.90
|
||||
|
||||
threshold = MinPrecisionThreshold(min_precision=0.65)
|
||||
assert threshold.find(clf, x_train, y_train) == 0.80
|
||||
threshold.fit(clf, x_train, y_train)
|
||||
assert threshold.predict(x_train) == 0.80
|
||||
|
||||
threshold = MinPrecisionThreshold(min_precision=0.50)
|
||||
assert threshold.find(clf, x_train, y_train) == 0.70
|
||||
threshold.fit(clf, x_train, y_train)
|
||||
assert threshold.predict(x_train) == 0.70
|
||||
|
||||
threshold = MinPrecisionThreshold(min_precision=0.00)
|
||||
assert threshold.find(clf, x_train, y_train) == 0.70
|
||||
threshold.fit(clf, x_train, y_train)
|
||||
assert threshold.predict(x_train) == 0.70
|
||||
|
||||
Reference in New Issue
Block a user