Add type annotations to components

This commit is contained in:
2021-01-21 15:54:23 -06:00
parent a98a783969
commit fc0835e694
12 changed files with 122 additions and 76 deletions

View File

@@ -4,6 +4,7 @@
import logging
from copy import deepcopy
from typing import Any, Dict
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
@@ -25,9 +26,9 @@ class AdaptiveClassifier(Classifier):
def __init__(
self,
candidates=None,
evaluator=ClassifierEvaluator(),
):
candidates: Dict[str, Any] = None,
evaluator: ClassifierEvaluator = ClassifierEvaluator(),
) -> None:
"""
Initializes the meta-classifier.
"""

View File

@@ -15,7 +15,7 @@ class CountingClassifier(Classifier):
counts how many times each label appeared, hence the name.
"""
def __init__(self):
def __init__(self) -> None:
self.mean = None
def fit(self, x_train, y_train):

View File

@@ -6,7 +6,7 @@ from sklearn.metrics import roc_auc_score
class ClassifierEvaluator:
def __init__(self):
def __init__(self) -> None:
pass
def evaluate(self, clf, x_train, y_train):

View File

@@ -7,10 +7,17 @@ from abc import abstractmethod, ABC
import numpy as np
from sklearn.metrics._ranking import _binary_clf_curve
from miplearn.classifiers import Classifier
class DynamicThreshold(ABC):
@abstractmethod
def find(self, clf, x_train, y_train):
def find(
self,
clf: Classifier,
x_train: np.ndarray,
y_train: np.ndarray,
) -> float:
"""
Given a trained binary classifier `clf` and a training data set,
returns the numerical threshold (float) satisfying some criterea.
@@ -24,7 +31,7 @@ class MinPrecisionThreshold(DynamicThreshold):
positive rate (also known as precision).
"""
def __init__(self, min_precision):
def __init__(self, min_precision: float) -> None:
self.min_precision = min_precision
def find(self, clf, x_train, y_train):