LazyStatic: Use dynamic thresholds

master
Alinson S. Xavier 5 years ago
parent 08e808690e
commit 025e08f85e
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -11,6 +11,7 @@ from tqdm.auto import tqdm
from miplearn import Classifier from miplearn import Classifier
from miplearn.classifiers.counting import CountingClassifier from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
from miplearn.components.component import Component from miplearn.components.component import Component
from miplearn.types import TrainingSample, Features, LearningSolveStats from miplearn.types import TrainingSample, Features, LearningSolveStats
@ -36,13 +37,14 @@ class StaticLazyConstraintsComponent(Component):
def __init__( def __init__(
self, self,
classifier: Classifier = CountingClassifier(), classifier: Classifier = CountingClassifier(),
threshold: float = 0.05, threshold: Threshold = MinProbabilityThreshold([0.50, 0.50]),
violation_tolerance: float = -0.5, violation_tolerance: float = -0.5,
) -> None: ) -> None:
assert isinstance(classifier, Classifier) assert isinstance(classifier, Classifier)
self.threshold: float = threshold
self.classifier_prototype: Classifier = classifier self.classifier_prototype: Classifier = classifier
self.threshold_prototype: Threshold = threshold
self.classifiers: Dict[Hashable, Classifier] = {} self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.pool: Dict[str, LazyConstraint] = {} self.pool: Dict[str, LazyConstraint] = {}
self.violation_tolerance: float = violation_tolerance self.violation_tolerance: float = violation_tolerance
self.enforced_cids: Set[str] = set() self.enforced_cids: Set[str] = set()
@ -156,9 +158,10 @@ class StaticLazyConstraintsComponent(Component):
for category in x.keys(): for category in x.keys():
if category not in self.classifiers: if category not in self.classifiers:
continue continue
clf = self.classifiers[category] npx = np.array(x[category])
proba = clf.predict_proba(np.array(x[category])) proba = self.classifiers[category].predict_proba(npx)
pred = list(proba[:, 1] > self.threshold) thr = self.thresholds[category].predict(npx)
pred = list(proba[:, 1] > thr[1])
for (i, is_selected) in enumerate(pred): for (i, is_selected) in enumerate(pred):
if is_selected: if is_selected:
enforced_cids += [category_to_cids[category][i]] enforced_cids += [category_to_cids[category][i]]
@ -196,4 +199,6 @@ class StaticLazyConstraintsComponent(Component):
for c in y.keys(): for c in y.keys():
assert c in x assert c in x
self.classifiers[c] = self.classifier_prototype.clone() self.classifiers[c] = self.classifier_prototype.clone()
self.thresholds[c] = self.threshold_prototype.clone()
self.classifiers[c].fit(x[c], y[c]) self.classifiers[c].fit(x[c], y[c])
self.thresholds[c].fit(self.classifiers[c], x[c], y[c])

@ -10,6 +10,7 @@ from numpy.testing import assert_array_equal
from miplearn import LearningSolver, InternalSolver, Instance from miplearn import LearningSolver, InternalSolver, Instance
from miplearn.classifiers import Classifier from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold
from miplearn.components.lazy_static import StaticLazyConstraintsComponent from miplearn.components.lazy_static import StaticLazyConstraintsComponent
from miplearn.types import TrainingSample, Features, LearningSolveStats from miplearn.types import TrainingSample, Features, LearningSolveStats
@ -69,10 +70,9 @@ def test_usage_with_solver(features: Features) -> None:
instance = Mock(spec=Instance) instance = Mock(spec=Instance)
instance.has_static_lazy_constraints = Mock(return_value=True) instance.has_static_lazy_constraints = Mock(return_value=True)
component = StaticLazyConstraintsComponent( component = StaticLazyConstraintsComponent(violation_tolerance=1.0)
threshold=0.50, component.thresholds["type-a"] = MinProbabilityThreshold([0.5, 0.5])
violation_tolerance=1.0, component.thresholds["type-b"] = MinProbabilityThreshold([0.5, 0.5])
)
component.classifiers = { component.classifiers = {
"type-a": Mock(spec=Classifier), "type-a": Mock(spec=Classifier),
"type-b": Mock(spec=Classifier), "type-b": Mock(spec=Classifier),
@ -158,7 +158,9 @@ def test_sample_predict(
features: Features, features: Features,
sample: TrainingSample, sample: TrainingSample,
) -> None: ) -> None:
comp = StaticLazyConstraintsComponent(threshold=0.5) comp = StaticLazyConstraintsComponent()
comp.thresholds["type-a"] = MinProbabilityThreshold([0.5, 0.5])
comp.thresholds["type-b"] = MinProbabilityThreshold([0.5, 0.5])
comp.classifiers["type-a"] = Mock(spec=Classifier) comp.classifiers["type-a"] = Mock(spec=Classifier)
comp.classifiers["type-a"].predict_proba = lambda _: np.array( # type:ignore comp.classifiers["type-a"].predict_proba = lambda _: np.array( # type:ignore
[ [
@ -192,9 +194,14 @@ def test_fit_xy() -> None:
"type-b": np.array([[False, True]]), "type-b": np.array([[False, True]]),
}, },
) )
clf = Mock(spec=Classifier) clf: Classifier = Mock(spec=Classifier)
clf.clone = Mock(side_effect=lambda: Mock(spec=Classifier)) thr: Threshold = Mock(spec=Threshold)
comp = StaticLazyConstraintsComponent(classifier=clf) clf.clone = Mock(side_effect=lambda: Mock(spec=Classifier)) # type: ignore
thr.clone = Mock(side_effect=lambda: Mock(spec=Threshold)) # type: ignore
comp = StaticLazyConstraintsComponent(
classifier=clf,
threshold=thr,
)
comp.fit_xy(x, y) comp.fit_xy(x, y)
assert clf.clone.call_count == 2 assert clf.clone.call_count == 2
clf_a = comp.classifiers["type-a"] clf_a = comp.classifiers["type-a"]
@ -203,6 +210,13 @@ def test_fit_xy() -> None:
assert clf_b.fit.call_count == 1 # type: ignore assert clf_b.fit.call_count == 1 # type: ignore
assert_array_equal(clf_a.fit.call_args[0][0], x["type-a"]) # type: ignore assert_array_equal(clf_a.fit.call_args[0][0], x["type-a"]) # type: ignore
assert_array_equal(clf_b.fit.call_args[0][0], x["type-b"]) # type: ignore assert_array_equal(clf_b.fit.call_args[0][0], x["type-b"]) # type: ignore
assert thr.clone.call_count == 2
thr_a = comp.thresholds["type-a"]
thr_b = comp.thresholds["type-b"]
assert thr_a.fit.call_count == 1 # type: ignore
assert thr_b.fit.call_count == 1 # type: ignore
assert thr_a.fit.call_args[0][0] == clf_a # type: ignore
assert thr_b.fit.call_args[0][0] == clf_b # type: ignore
def test_sample_xy( def test_sample_xy(

Loading…
Cancel
Save