Refactor thresholds

This commit is contained in:
2021-01-25 09:52:49 -06:00
parent 4da561a6a8
commit f68cc5bd59
4 changed files with 82 additions and 41 deletions

View File

@@ -11,7 +11,7 @@ from tqdm.auto import tqdm
from miplearn.classifiers import Classifier
from miplearn.classifiers.adaptive import AdaptiveClassifier
from miplearn.classifiers.threshold import MinPrecisionThreshold, DynamicThreshold
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.component import Component
from miplearn.extractors import VariableFeaturesExtractor, SolutionExtractor, Extractor
@@ -28,11 +28,11 @@ class PrimalSolutionComponent(Component):
self,
classifier: Classifier = AdaptiveClassifier(),
mode: str = "exact",
threshold: Union[float, DynamicThreshold] = MinPrecisionThreshold(0.98),
threshold: Union[float, Threshold] = MinPrecisionThreshold(0.98),
) -> None:
self.mode = mode
self.classifiers: Dict[Any, Classifier] = {}
self.thresholds: Dict[Any, Union[float, DynamicThreshold]] = {}
self.thresholds: Dict[Any, Union[float, Threshold]] = {}
self.threshold_prototype = threshold
self.classifier_prototype = classifier
@@ -89,8 +89,8 @@ class PrimalSolutionComponent(Component):
clf.fit(x_train, y_train)
# Find threshold (dynamic or static)
if isinstance(self.threshold_prototype, DynamicThreshold):
self.thresholds[category, label] = self.threshold_prototype.find(
if isinstance(self.threshold_prototype, Threshold):
self.thresholds[category, label] = self.threshold_prototype.fit(
clf,
x_train,
y_train,