mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Refactor thresholds
This commit is contained in:
@@ -11,7 +11,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.adaptive import AdaptiveClassifier
|
||||
from miplearn.classifiers.threshold import MinPrecisionThreshold, DynamicThreshold
|
||||
from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.extractors import VariableFeaturesExtractor, SolutionExtractor, Extractor
|
||||
@@ -28,11 +28,11 @@ class PrimalSolutionComponent(Component):
|
||||
self,
|
||||
classifier: Classifier = AdaptiveClassifier(),
|
||||
mode: str = "exact",
|
||||
threshold: Union[float, DynamicThreshold] = MinPrecisionThreshold(0.98),
|
||||
threshold: Union[float, Threshold] = MinPrecisionThreshold(0.98),
|
||||
) -> None:
|
||||
self.mode = mode
|
||||
self.classifiers: Dict[Any, Classifier] = {}
|
||||
self.thresholds: Dict[Any, Union[float, DynamicThreshold]] = {}
|
||||
self.thresholds: Dict[Any, Union[float, Threshold]] = {}
|
||||
self.threshold_prototype = threshold
|
||||
self.classifier_prototype = classifier
|
||||
|
||||
@@ -89,8 +89,8 @@ class PrimalSolutionComponent(Component):
|
||||
clf.fit(x_train, y_train)
|
||||
|
||||
# Find threshold (dynamic or static)
|
||||
if isinstance(self.threshold_prototype, DynamicThreshold):
|
||||
self.thresholds[category, label] = self.threshold_prototype.find(
|
||||
if isinstance(self.threshold_prototype, Threshold):
|
||||
self.thresholds[category, label] = self.threshold_prototype.fit(
|
||||
clf,
|
||||
x_train,
|
||||
y_train,
|
||||
|
||||
Reference in New Issue
Block a user