def test_threshold_dynamic():
    """
    MinPrecisionThreshold.find should return the smallest score threshold
    whose precision on the training data reaches the requested minimum.
    """
    # Probabilities of the positive class are 0.90, 0.90, 0.80 and 0.70;
    # only the first two samples are actually positive, so the precision
    # at thresholds 0.90 / 0.80 / 0.70 is 1.00 / 0.67 / 0.50 respectively.
    probabilities = np.array([
        [0.10, 0.90],
        [0.10, 0.90],
        [0.20, 0.80],
        [0.30, 0.70],
    ])
    clf = Mock(spec=Classifier)
    clf.predict_proba = Mock(return_value=probabilities)
    x_train = np.array([0, 1, 2, 3])
    y_train = np.array([1, 1, 0, 0])

    # (minimum precision requested, threshold expected)
    cases = [
        (1.0, 0.90),
        (0.65, 0.80),
        (0.50, 0.70),
        (0.00, 0.70),
    ]
    for min_precision, expected in cases:
        threshold = MinPrecisionThreshold(min_precision=min_precision)
        assert threshold.find(clf, x_train, y_train) == expected
class DynamicThreshold(ABC):
    """
    Interface for strategies that pick a classification threshold from
    training data instead of using a fixed number.
    """

    @abstractmethod
    def find(self, clf, x_train, y_train):
        """
        Given a trained binary classifier `clf` and a training data set,
        returns the numerical threshold (float) satisfying some criteria.
        """


def _precision_by_threshold(y_true, scores):
    """
    Computes the precision obtained at every distinct score threshold.

    Samples whose score is greater than or equal to a threshold are counted
    as predicted positives; samples whose label equals 1 are the actual
    positives. The computation is done directly with NumPy so that it does
    not rely on scikit-learn's private `_binary_clf_curve` helper.

    Parameters
    ----------
    y_true
        One-dimensional array-like of labels; entries equal to 1 are
        treated as positive, everything else as negative.
    scores
        One-dimensional array-like with the predicted score of the positive
        class for each sample.

    Returns
    -------
    (precision, thresholds)
        Two arrays of equal length. `thresholds` holds the distinct scores
        in decreasing order and `precision[k]` is the precision obtained
        when predicting positive for every score >= `thresholds[k]`. Both
        arrays are empty if there are no samples.

    Raises
    ------
    ValueError
        If `y_true` and `scores` have different lengths.
    """
    labels = np.asarray(y_true).ravel() == 1
    scores = np.asarray(scores, dtype=float).ravel()
    if labels.shape[0] != scores.shape[0]:
        raise ValueError(
            "y_true and scores should have the same length (%d != %d)" % (
                labels.shape[0], scores.shape[0]))
    if scores.size == 0:
        return np.array([]), np.array([])

    # Sort samples by decreasing score, so that the predicted positives for
    # any threshold form a prefix of the sorted arrays.
    order = np.argsort(scores, kind="stable")[::-1]
    scores = scores[order]
    labels = labels[order]

    # Index of the last sample in each run of equal scores; each of these
    # positions corresponds to one distinct threshold.
    last = np.r_[np.flatnonzero(np.diff(scores)), scores.size - 1]
    true_positives = np.cumsum(labels)[last]
    predicted_positives = last + 1  # always >= 1, so no division by zero
    return true_positives / predicted_positives, scores[last]


class MinPrecisionThreshold(DynamicThreshold):
    """
    The smallest possible threshold satisfying a minimum acceptable
    precision (also known as positive predictive value), that is, the
    fraction of predicted positives which are actual positives.
    """

    # Value returned when no candidate threshold is precise enough. Since
    # probabilities never exceed 1.0, a `proba >= threshold` comparison
    # against this value never succeeds.
    _NO_THRESHOLD = 2.0

    def __init__(self, min_precision):
        """
        Parameters
        ----------
        min_precision
            Minimum precision that the returned threshold must achieve on
            the training data.
        """
        self.min_precision = min_precision

    def find(self, clf, x_train, y_train):
        """
        Returns the smallest threshold whose training precision is at
        least `min_precision`.

        Parameters
        ----------
        clf
            Trained binary classifier; `clf.predict_proba(x_train)` must
            return a numpy array with one row per sample and two columns,
            the second one being the score of the positive class.
        x_train
            Training features; must expose `shape[0]` (number of samples).
        y_train
            Training labels, with 1 denoting the positive class.

        Returns
        -------
        float
            The smallest score threshold reaching the minimum precision,
            or 2.0 if no threshold does (including when there are no
            training samples).
        """
        proba = clf.predict_proba(x_train)

        assert isinstance(proba, np.ndarray), \
            "classifier should return numpy array"
        assert proba.shape == (x_train.shape[0], 2), \
            "classifier should return (%d,%d)-shaped array, not %s" % (
                x_train.shape[0], 2, str(proba.shape))

        precision, thresholds = _precision_by_threshold(y_train, proba[:, 1])

        # Thresholds are in decreasing order, so the last acceptable entry
        # is the smallest acceptable threshold.
        acceptable = np.flatnonzero(precision >= self.min_precision)
        if acceptable.size == 0:
            return self._NO_THRESHOLD
        return float(thresholds[acceptable[-1]])
threshold=MinPrecisionThreshold(0.95)): self.mode = mode self.classifiers = {} - self.threshold = threshold + self.thresholds = {} + self.threshold_prototype = threshold self.classifier_prototype = classifier def before_solve(self, solver, instance, model): @@ -51,6 +52,7 @@ class PrimalSolutionComponent(Component): y_avg = np.average(y_train) if y_avg < 0.001 or y_avg >= 0.999: self.classifiers[category, label] = round(y_avg) + self.thresholds[category, label] = 0.50 continue # Create a copy of classifier prototype and train it @@ -60,6 +62,12 @@ class PrimalSolutionComponent(Component): clf = deepcopy(self.classifier_prototype) clf.fit(x_train, y_train) + # Find threshold (dynamic or static) + if isinstance(self.threshold_prototype, DynamicThreshold): + self.thresholds[category, label] = self.threshold_prototype.find(clf, x_train, y_train) + else: + self.thresholds[category, label] = deepcopy(self.threshold_prototype) + self.classifiers[category, label] = clf def predict(self, instance): @@ -82,7 +90,7 @@ class PrimalSolutionComponent(Component): ws = clf.predict_proba(x_test[category]) assert ws.shape == (n, 2), "ws.shape should be (%d, 2) not %s" % (n, ws.shape) for (i, (var, index)) in enumerate(var_split[category]): - if ws[i, 1] >= self.threshold: + if ws[i, 1] >= self.thresholds[category, label]: solution[var][index] = label return solution diff --git a/src/python/miplearn/components/tests/test_primal.py b/src/python/miplearn/components/tests/test_primal.py index 8aaa50f..ee828e8 100644 --- a/src/python/miplearn/components/tests/test_primal.py +++ b/src/python/miplearn/components/tests/test_primal.py @@ -38,7 +38,8 @@ def test_evaluate(): [0., 1.], # x[2] instances[0] [1., 0.], # x[3] instances[0] ])) - comp = PrimalSolutionComponent(classifier=[clf_zero, clf_one]) + comp = PrimalSolutionComponent(classifier=[clf_zero, clf_one], + threshold=0.50) comp.fit(instances[:1]) assert comp.predict(instances[0]) == {"x": {0: 0, 1: 0,