mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 01:48:51 -06:00
Start refactoring of classifiers
This commit is contained in:
@@ -12,7 +12,27 @@ E = 0.1
|
||||
|
||||
def test_counting():
|
||||
clf = CountingClassifier()
|
||||
clf.fit(np.zeros((8, 25)), [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
expected_proba = np.array([[0.375, 0.625], [0.375, 0.625]])
|
||||
actual_proba = clf.predict_proba(np.zeros((2, 25)))
|
||||
assert norm(actual_proba - expected_proba) < E
|
||||
n_features = 25
|
||||
x_train = np.zeros((8, n_features))
|
||||
y_train = np.array(
|
||||
[
|
||||
[True, False, False],
|
||||
[True, False, False],
|
||||
[False, True, False],
|
||||
[True, False, False],
|
||||
[False, True, False],
|
||||
[False, True, False],
|
||||
[False, True, False],
|
||||
[False, False, True],
|
||||
]
|
||||
)
|
||||
x_test = np.zeros((2, n_features))
|
||||
y_expected = np.array(
|
||||
[
|
||||
[3 / 8.0, 4 / 8.0, 1 / 8.0],
|
||||
[3 / 8.0, 4 / 8.0, 1 / 8.0],
|
||||
]
|
||||
)
|
||||
clf.fit(x_train, y_train)
|
||||
y_actual = clf.predict_proba(x_test)
|
||||
assert norm(y_actual - y_expected) < E
|
||||
|
||||
Reference in New Issue
Block a user