mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Components: Switch from factory methods to prototype objects
This commit is contained in:
@@ -50,18 +50,18 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
classifier: Callable[[], Classifier] = lambda: AdaptiveClassifier(),
|
||||
classifier: Classifier = AdaptiveClassifier(),
|
||||
mode: str = "exact",
|
||||
threshold: Callable[[], Threshold] = lambda: MinPrecisionThreshold(
|
||||
[0.98, 0.98]
|
||||
),
|
||||
threshold: Threshold = MinPrecisionThreshold([0.98, 0.98]),
|
||||
) -> None:
|
||||
assert isinstance(classifier, Classifier)
|
||||
assert isinstance(threshold, Threshold)
|
||||
assert mode in ["exact", "heuristic"]
|
||||
self.mode = mode
|
||||
self.classifiers: Dict[Hashable, Classifier] = {}
|
||||
self.thresholds: Dict[Hashable, Threshold] = {}
|
||||
self.threshold_factory = threshold
|
||||
self.classifier_factory = classifier
|
||||
self.threshold_prototype = threshold
|
||||
self.classifier_prototype = classifier
|
||||
self.stats: Dict[str, float] = {}
|
||||
self._n_free = 0
|
||||
self._n_zero = 0
|
||||
@@ -114,8 +114,8 @@ class PrimalSolutionComponent(Component):
|
||||
y: Dict[str, np.ndarray],
|
||||
) -> None:
|
||||
for category in x.keys():
|
||||
clf = self.classifier_factory()
|
||||
thr = self.threshold_factory()
|
||||
clf = self.classifier_prototype.clone()
|
||||
thr = self.threshold_prototype.clone()
|
||||
clf.fit(x[category], y[category])
|
||||
thr.fit(clf, x[category], y[category])
|
||||
self.classifiers[category] = clf
|
||||
|
||||
Reference in New Issue
Block a user