mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
LazyStatic: Use dynamic thresholds
This commit is contained in:
@@ -11,6 +11,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
from miplearn import Classifier
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
|
||||
from miplearn.components.component import Component
|
||||
from miplearn.types import TrainingSample, Features, LearningSolveStats
|
||||
|
||||
@@ -36,13 +37,14 @@ class StaticLazyConstraintsComponent(Component):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: Classifier = CountingClassifier(),
|
||||
threshold: float = 0.05,
|
||||
threshold: Threshold = MinProbabilityThreshold([0.50, 0.50]),
|
||||
violation_tolerance: float = -0.5,
|
||||
) -> None:
|
||||
assert isinstance(classifier, Classifier)
|
||||
self.threshold: float = threshold
|
||||
self.classifier_prototype: Classifier = classifier
|
||||
self.threshold_prototype: Threshold = threshold
|
||||
self.classifiers: Dict[Hashable, Classifier] = {}
|
||||
self.thresholds: Dict[Hashable, Threshold] = {}
|
||||
self.pool: Dict[str, LazyConstraint] = {}
|
||||
self.violation_tolerance: float = violation_tolerance
|
||||
self.enforced_cids: Set[str] = set()
|
||||
@@ -156,9 +158,10 @@ class StaticLazyConstraintsComponent(Component):
|
||||
for category in x.keys():
|
||||
if category not in self.classifiers:
|
||||
continue
|
||||
clf = self.classifiers[category]
|
||||
proba = clf.predict_proba(np.array(x[category]))
|
||||
pred = list(proba[:, 1] > self.threshold)
|
||||
npx = np.array(x[category])
|
||||
proba = self.classifiers[category].predict_proba(npx)
|
||||
thr = self.thresholds[category].predict(npx)
|
||||
pred = list(proba[:, 1] > thr[1])
|
||||
for (i, is_selected) in enumerate(pred):
|
||||
if is_selected:
|
||||
enforced_cids += [category_to_cids[category][i]]
|
||||
@@ -196,4 +199,6 @@ class StaticLazyConstraintsComponent(Component):
|
||||
for c in y.keys():
|
||||
assert c in x
|
||||
self.classifiers[c] = self.classifier_prototype.clone()
|
||||
self.thresholds[c] = self.threshold_prototype.clone()
|
||||
self.classifiers[c].fit(x[c], y[c])
|
||||
self.thresholds[c].fit(self.classifiers[c], x[c], y[c])
|
||||
|
||||
Reference in New Issue
Block a user