LazyStatic: Use dynamic thresholds

This commit is contained in:
2021-04-04 20:42:04 -05:00
parent 08e808690e
commit 025e08f85e
2 changed files with 32 additions and 13 deletions

View File

@@ -11,6 +11,7 @@ from tqdm.auto import tqdm
from miplearn import Classifier from miplearn import Classifier
from miplearn.classifiers.counting import CountingClassifier from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
from miplearn.components.component import Component from miplearn.components.component import Component
from miplearn.types import TrainingSample, Features, LearningSolveStats from miplearn.types import TrainingSample, Features, LearningSolveStats
@@ -36,13 +37,14 @@ class StaticLazyConstraintsComponent(Component):
def __init__( def __init__(
self, self,
classifier: Classifier = CountingClassifier(), classifier: Classifier = CountingClassifier(),
threshold: float = 0.05, threshold: Threshold = MinProbabilityThreshold([0.50, 0.50]),
violation_tolerance: float = -0.5, violation_tolerance: float = -0.5,
) -> None: ) -> None:
assert isinstance(classifier, Classifier) assert isinstance(classifier, Classifier)
self.threshold: float = threshold
self.classifier_prototype: Classifier = classifier self.classifier_prototype: Classifier = classifier
self.threshold_prototype: Threshold = threshold
self.classifiers: Dict[Hashable, Classifier] = {} self.classifiers: Dict[Hashable, Classifier] = {}
self.thresholds: Dict[Hashable, Threshold] = {}
self.pool: Dict[str, LazyConstraint] = {} self.pool: Dict[str, LazyConstraint] = {}
self.violation_tolerance: float = violation_tolerance self.violation_tolerance: float = violation_tolerance
self.enforced_cids: Set[str] = set() self.enforced_cids: Set[str] = set()
@@ -156,9 +158,10 @@ class StaticLazyConstraintsComponent(Component):
for category in x.keys(): for category in x.keys():
if category not in self.classifiers: if category not in self.classifiers:
continue continue
clf = self.classifiers[category] npx = np.array(x[category])
proba = clf.predict_proba(np.array(x[category])) proba = self.classifiers[category].predict_proba(npx)
pred = list(proba[:, 1] > self.threshold) thr = self.thresholds[category].predict(npx)
pred = list(proba[:, 1] > thr[1])
for (i, is_selected) in enumerate(pred): for (i, is_selected) in enumerate(pred):
if is_selected: if is_selected:
enforced_cids += [category_to_cids[category][i]] enforced_cids += [category_to_cids[category][i]]
@@ -196,4 +199,6 @@ class StaticLazyConstraintsComponent(Component):
for c in y.keys(): for c in y.keys():
assert c in x assert c in x
self.classifiers[c] = self.classifier_prototype.clone() self.classifiers[c] = self.classifier_prototype.clone()
self.thresholds[c] = self.threshold_prototype.clone()
self.classifiers[c].fit(x[c], y[c]) self.classifiers[c].fit(x[c], y[c])
self.thresholds[c].fit(self.classifiers[c], x[c], y[c])

View File

@@ -10,6 +10,7 @@ from numpy.testing import assert_array_equal
from miplearn import LearningSolver, InternalSolver, Instance from miplearn import LearningSolver, InternalSolver, Instance
from miplearn.classifiers import Classifier from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold
from miplearn.components.lazy_static import StaticLazyConstraintsComponent from miplearn.components.lazy_static import StaticLazyConstraintsComponent
from miplearn.types import TrainingSample, Features, LearningSolveStats from miplearn.types import TrainingSample, Features, LearningSolveStats
@@ -69,10 +70,9 @@ def test_usage_with_solver(features: Features) -> None:
instance = Mock(spec=Instance) instance = Mock(spec=Instance)
instance.has_static_lazy_constraints = Mock(return_value=True) instance.has_static_lazy_constraints = Mock(return_value=True)
component = StaticLazyConstraintsComponent( component = StaticLazyConstraintsComponent(violation_tolerance=1.0)
threshold=0.50, component.thresholds["type-a"] = MinProbabilityThreshold([0.5, 0.5])
violation_tolerance=1.0, component.thresholds["type-b"] = MinProbabilityThreshold([0.5, 0.5])
)
component.classifiers = { component.classifiers = {
"type-a": Mock(spec=Classifier), "type-a": Mock(spec=Classifier),
"type-b": Mock(spec=Classifier), "type-b": Mock(spec=Classifier),
@@ -158,7 +158,9 @@ def test_sample_predict(
features: Features, features: Features,
sample: TrainingSample, sample: TrainingSample,
) -> None: ) -> None:
comp = StaticLazyConstraintsComponent(threshold=0.5) comp = StaticLazyConstraintsComponent()
comp.thresholds["type-a"] = MinProbabilityThreshold([0.5, 0.5])
comp.thresholds["type-b"] = MinProbabilityThreshold([0.5, 0.5])
comp.classifiers["type-a"] = Mock(spec=Classifier) comp.classifiers["type-a"] = Mock(spec=Classifier)
comp.classifiers["type-a"].predict_proba = lambda _: np.array( # type:ignore comp.classifiers["type-a"].predict_proba = lambda _: np.array( # type:ignore
[ [
@@ -192,9 +194,14 @@ def test_fit_xy() -> None:
"type-b": np.array([[False, True]]), "type-b": np.array([[False, True]]),
}, },
) )
clf = Mock(spec=Classifier) clf: Classifier = Mock(spec=Classifier)
clf.clone = Mock(side_effect=lambda: Mock(spec=Classifier)) thr: Threshold = Mock(spec=Threshold)
comp = StaticLazyConstraintsComponent(classifier=clf) clf.clone = Mock(side_effect=lambda: Mock(spec=Classifier)) # type: ignore
thr.clone = Mock(side_effect=lambda: Mock(spec=Threshold)) # type: ignore
comp = StaticLazyConstraintsComponent(
classifier=clf,
threshold=thr,
)
comp.fit_xy(x, y) comp.fit_xy(x, y)
assert clf.clone.call_count == 2 assert clf.clone.call_count == 2
clf_a = comp.classifiers["type-a"] clf_a = comp.classifiers["type-a"]
@@ -203,6 +210,13 @@ def test_fit_xy() -> None:
assert clf_b.fit.call_count == 1 # type: ignore assert clf_b.fit.call_count == 1 # type: ignore
assert_array_equal(clf_a.fit.call_args[0][0], x["type-a"]) # type: ignore assert_array_equal(clf_a.fit.call_args[0][0], x["type-a"]) # type: ignore
assert_array_equal(clf_b.fit.call_args[0][0], x["type-b"]) # type: ignore assert_array_equal(clf_b.fit.call_args[0][0], x["type-b"]) # type: ignore
assert thr.clone.call_count == 2
thr_a = comp.thresholds["type-a"]
thr_b = comp.thresholds["type-b"]
assert thr_a.fit.call_count == 1 # type: ignore
assert thr_b.fit.call_count == 1 # type: ignore
assert thr_a.fit.call_args[0][0] == clf_a # type: ignore
assert thr_b.fit.call_args[0][0] == clf_b # type: ignore
def test_sample_xy( def test_sample_xy(