Refactor PrimalSolutionComponent

This commit is contained in:
2021-01-25 14:54:58 -06:00
parent f68cc5bd59
commit 3ab3bb3c1f
9 changed files with 501 additions and 233 deletions

View File

@@ -16,27 +16,41 @@ def test_threshold_dynamic():
return_value=np.array(
[
[0.10, 0.90],
[0.10, 0.90],
[0.20, 0.80],
[0.30, 0.70],
[0.25, 0.75],
[0.40, 0.60],
[0.90, 0.10],
]
)
)
x_train = np.array([0, 1, 2, 3])
y_train = np.array([1, 1, 0, 0])
x_train = np.array(
[
[0],
[1],
[2],
[3],
]
)
y_train = np.array(
[
[False, True],
[False, True],
[True, False],
[True, False],
]
)
threshold = MinPrecisionThreshold(min_precision=1.0)
threshold = MinPrecisionThreshold(min_precision=[1.0, 1.0])
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.90
assert threshold.predict(x_train) == [0.40, 0.75]
threshold = MinPrecisionThreshold(min_precision=0.65)
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.80
# threshold = MinPrecisionThreshold(min_precision=0.65)
# threshold.fit(clf, x_train, y_train)
# assert threshold.predict(x_train) == [0.0, 0.80]
threshold = MinPrecisionThreshold(min_precision=0.50)
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.70
threshold = MinPrecisionThreshold(min_precision=0.00)
threshold.fit(clf, x_train, y_train)
assert threshold.predict(x_train) == 0.70
# threshold = MinPrecisionThreshold(min_precision=0.50)
# threshold.fit(clf, x_train, y_train)
# assert threshold.predict(x_train) == [0.0, 0.70]
#
# threshold = MinPrecisionThreshold(min_precision=0.00)
# threshold.fit(clf, x_train, y_train)
# assert threshold.predict(x_train) == [0.0, 0.70]