You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
1.6 KiB
57 lines
1.6 KiB
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
from unittest.mock import Mock
|
|
|
|
import numpy as np
|
|
|
|
from miplearn.classifiers import Classifier
|
|
from miplearn.classifiers.threshold import MinPrecisionThreshold
|
|
|
|
|
|
def test_threshold_dynamic() -> None:
    """Check MinPrecisionThreshold picks per-class thresholds from training data.

    The classifier is a mock whose ``predict_proba`` always returns the fixed
    4x2 probability table below. Column 0 is labeled positive for samples 2
    and 3, and column 1 for samples 0 and 1. Each case fits a fresh
    ``MinPrecisionThreshold`` with a per-class ``min_precision`` list. It then
    asserts that ``predict`` returns one threshold per class.

    The expected values are the smallest probability in each column whose
    induced precision on this fixture meets the requested minimum. Going down
    either column, precision is 1, 1, 2/3, 1/2.
    """
    clf = Mock(spec=Classifier)
    clf.predict_proba = Mock(
        return_value=np.array(
            [
                [0.10, 0.90],
                [0.25, 0.75],
                [0.40, 0.60],
                [0.90, 0.10],
            ]
        )
    )
    x_train = np.array(
        [
            [0],
            [1],
            [2],
            [3],
        ]
    )
    y_train = np.array(
        [
            [False, True],
            [False, True],
            [True, False],
            [True, False],
        ]
    )

    # (per-class min_precision, expected per-class thresholds)
    cases = [
        # Perfect precision: stop at the last all-correct score in each column.
        ([1.0, 1.0], [0.40, 0.75]),
        # 2/3 >= 0.65, so one false positive per class is tolerated.
        ([0.65, 0.65], [0.25, 0.60]),
        # Any precision is acceptable: the lowest score in each column wins.
        ([0.0, 0.0], [0.10, 0.10]),
    ]
    for min_precision, expected in cases:
        # Fresh instance per case so no fitted state leaks between cases.
        threshold = MinPrecisionThreshold(min_precision=min_precision)
        threshold.fit(clf, x_train, y_train)
        assert threshold.predict(x_train) == expected, min_precision
|