mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Add method LazyConstraintsComponent.evaluate
This commit is contained in:
@@ -29,13 +29,7 @@ class LazyConstraintsComponent(Component):
|
|||||||
|
|
||||||
def before_solve(self, solver, instance, model):
|
def before_solve(self, solver, instance, model):
|
||||||
logger.info("Predicting violated lazy constraints...")
|
logger.info("Predicting violated lazy constraints...")
|
||||||
violations = []
|
violations = self.predict(instance)
|
||||||
features = InstanceFeaturesExtractor().extract([instance])
|
|
||||||
for (v, classifier) in self.classifiers.items():
|
|
||||||
proba = classifier.predict_proba(features)
|
|
||||||
if proba[0][1] > self.threshold:
|
|
||||||
violations += [v]
|
|
||||||
|
|
||||||
logger.info("Enforcing %d constraints..." % len(violations))
|
logger.info("Enforcing %d constraints..." % len(violations))
|
||||||
for v in violations:
|
for v in violations:
|
||||||
cut = instance.build_lazy_constraint(model, v)
|
cut = instance.build_lazy_constraint(model, v)
|
||||||
@@ -57,11 +51,70 @@ class LazyConstraintsComponent(Component):
|
|||||||
violation_to_instance_idx[v] = []
|
violation_to_instance_idx[v] = []
|
||||||
violation_to_instance_idx[v] += [idx]
|
violation_to_instance_idx[v] += [idx]
|
||||||
|
|
||||||
for (v, classifier) in self.classifiers.items():
|
for (v, classifier) in tqdm(self.classifiers.items(), desc="Fit (lazy)"):
|
||||||
logger.debug("Training: %s" % (str(v)))
|
logger.debug("Training: %s" % (str(v)))
|
||||||
label = np.zeros(len(training_instances))
|
label = np.zeros(len(training_instances))
|
||||||
label[violation_to_instance_idx[v]] = 1.0
|
label[violation_to_instance_idx[v]] = 1.0
|
||||||
classifier.fit(features, label)
|
classifier.fit(features, label)
|
||||||
|
|
||||||
def predict(self, instance, model=None):
|
def predict(self, instance):
|
||||||
return self.violations
|
violations = []
|
||||||
|
features = InstanceFeaturesExtractor().extract([instance])
|
||||||
|
for (v, classifier) in self.classifiers.items():
|
||||||
|
proba = classifier.predict_proba(features)
|
||||||
|
if proba[0][1] > self.threshold:
|
||||||
|
violations += [v]
|
||||||
|
return violations
|
||||||
|
|
||||||
|
def evaluate(self, instances):
|
||||||
|
|
||||||
|
def _classifier_evaluation_dict(tp, tn, fp, fn):
|
||||||
|
p = tp + fn
|
||||||
|
n = fp + tn
|
||||||
|
d = {
|
||||||
|
"Predicted positive": fp + tp,
|
||||||
|
"Predicted negative": fn + tn,
|
||||||
|
"Condition positive": p,
|
||||||
|
"Condition negative": n,
|
||||||
|
"True positive": tp,
|
||||||
|
"True negative": tn,
|
||||||
|
"False positive": fp,
|
||||||
|
"False negative": fn,
|
||||||
|
}
|
||||||
|
d["Accuracy"] = (tp + tn) / (p + n)
|
||||||
|
d["F1 score"] = (2 * tp) / (2 * tp + fp + fn)
|
||||||
|
d["Recall"] = tp / p
|
||||||
|
d["Precision"] = tp / (tp + fp)
|
||||||
|
T = (p + n) / 100.0
|
||||||
|
d["Predicted positive (%)"] = d["Predicted positive"] / T
|
||||||
|
d["Predicted negative (%)"] = d["Predicted negative"] / T
|
||||||
|
d["Condition positive (%)"] = d["Condition positive"] / T
|
||||||
|
d["Condition negative (%)"] = d["Condition negative"] / T
|
||||||
|
d["True positive (%)"] = d["True positive"] / T
|
||||||
|
d["True negative (%)"] = d["True negative"] / T
|
||||||
|
d["False positive (%)"] = d["False positive"] / T
|
||||||
|
d["False negative (%)"] = d["False negative"] / T
|
||||||
|
return d
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
all_violations = set()
|
||||||
|
for instance in instances:
|
||||||
|
all_violations |= set(instance.found_violations)
|
||||||
|
|
||||||
|
for idx in tqdm(range(len(instances)), desc="Evaluate (lazy)"):
|
||||||
|
instance = instances[idx]
|
||||||
|
condition_positive = set(instance.found_violations)
|
||||||
|
condition_negative = all_violations - condition_positive
|
||||||
|
pred_positive = set(self.predict(instance)) & all_violations
|
||||||
|
pred_negative = all_violations - pred_positive
|
||||||
|
|
||||||
|
tp = len(pred_positive & condition_positive)
|
||||||
|
tn = len(pred_negative & condition_negative)
|
||||||
|
fp = len(pred_positive & condition_negative)
|
||||||
|
fn = len(pred_negative & condition_positive)
|
||||||
|
|
||||||
|
results[idx] = _classifier_evaluation_dict(tp, tn, fp, fn)
|
||||||
|
|
||||||
|
|
||||||
|
return results
|
||||||
@@ -77,3 +77,63 @@ def test_lazy_before():
|
|||||||
|
|
||||||
# Should ask internal solver to add generated constraint
|
# Should ask internal solver to add generated constraint
|
||||||
solver.internal_solver.add_constraint.assert_called_once_with("c1")
|
solver.internal_solver.add_constraint.assert_called_once_with("c1")
|
||||||
|
|
||||||
|
def test_lazy_evaluate():
|
||||||
|
instances, models = get_training_instances_and_models()
|
||||||
|
component = LazyConstraintsComponent()
|
||||||
|
component.classifiers = {"a": Mock(spec=Classifier),
|
||||||
|
"b": Mock(spec=Classifier),
|
||||||
|
"c": Mock(spec=Classifier)}
|
||||||
|
component.classifiers["a"].predict_proba = Mock(return_value=[[1.0, 0.0]])
|
||||||
|
component.classifiers["b"].predict_proba = Mock(return_value=[[0.0, 1.0]])
|
||||||
|
component.classifiers["c"].predict_proba = Mock(return_value=[[0.0, 1.0]])
|
||||||
|
|
||||||
|
instances[0].found_violations = ["a", "b", "c"]
|
||||||
|
instances[1].found_violations = ["b", "d"]
|
||||||
|
assert component.evaluate(instances) == {
|
||||||
|
0: {
|
||||||
|
"Accuracy": 0.75,
|
||||||
|
"F1 score": 0.8,
|
||||||
|
"Precision": 1.0,
|
||||||
|
"Recall": 2/3.,
|
||||||
|
"Predicted positive": 2,
|
||||||
|
"Predicted negative": 2,
|
||||||
|
"Condition positive": 3,
|
||||||
|
"Condition negative": 1,
|
||||||
|
"False negative": 1,
|
||||||
|
"False positive": 0,
|
||||||
|
"True negative": 1,
|
||||||
|
"True positive": 2,
|
||||||
|
"Predicted positive (%)": 50.0,
|
||||||
|
"Predicted negative (%)": 50.0,
|
||||||
|
"Condition positive (%)": 75.0,
|
||||||
|
"Condition negative (%)": 25.0,
|
||||||
|
"False negative (%)": 25.0,
|
||||||
|
"False positive (%)": 0,
|
||||||
|
"True negative (%)": 25.0,
|
||||||
|
"True positive (%)": 50.0,
|
||||||
|
},
|
||||||
|
1: {
|
||||||
|
"Accuracy": 0.5,
|
||||||
|
"F1 score": 0.5,
|
||||||
|
"Precision": 0.5,
|
||||||
|
"Recall": 0.5,
|
||||||
|
"Predicted positive": 2,
|
||||||
|
"Predicted negative": 2,
|
||||||
|
"Condition positive": 2,
|
||||||
|
"Condition negative": 2,
|
||||||
|
"False negative": 1,
|
||||||
|
"False positive": 1,
|
||||||
|
"True negative": 1,
|
||||||
|
"True positive": 1,
|
||||||
|
"Predicted positive (%)": 50.0,
|
||||||
|
"Predicted negative (%)": 50.0,
|
||||||
|
"Condition positive (%)": 50.0,
|
||||||
|
"Condition negative (%)": 50.0,
|
||||||
|
"False negative (%)": 25.0,
|
||||||
|
"False positive (%)": 25.0,
|
||||||
|
"True negative (%)": 25.0,
|
||||||
|
"True positive (%)": 25.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user