Components: Switch from factory methods to prototype objects

This commit is contained in:
2021-04-01 08:34:56 -05:00
parent 59c734f2a1
commit bc8fe4dc98
9 changed files with 43 additions and 34 deletions

View File

@@ -50,18 +50,18 @@ class PrimalSolutionComponent(Component):
def __init__(
self,
classifier: Callable[[], Classifier] = lambda: AdaptiveClassifier(),
classifier: Classifier = AdaptiveClassifier(),
mode: str = "exact",
threshold: Callable[[], Threshold] = lambda: MinPrecisionThreshold(
[0.98, 0.98]
),
threshold: Threshold = MinPrecisionThreshold([0.98, 0.98]),
) -> None:
assert isinstance(classifier, Classifier)
assert isinstance(threshold, Threshold)
assert mode in ["exact", "heuristic"]
self.mode = mode
self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.threshold_factory = threshold
self.classifier_factory = classifier
self.threshold_prototype = threshold
self.classifier_prototype = classifier
self.stats: Dict[str, float] = {}
self._n_free = 0
self._n_zero = 0
@@ -114,8 +114,8 @@ class PrimalSolutionComponent(Component):
y: Dict[str, np.ndarray],
) -> None:
for category in x.keys():
clf = self.classifier_factory()
thr = self.threshold_factory()
clf = self.classifier_prototype.clone()
thr = self.threshold_prototype.clone()
clf.fit(x[category], y[category])
thr.fit(clf, x[category], y[category])
self.classifiers[category] = clf