diff --git a/src/python/miplearn/components/primal.py b/src/python/miplearn/components/primal.py index e31887c..33b8e56 100644 --- a/src/python/miplearn/components/primal.py +++ b/src/python/miplearn/components/primal.py @@ -18,6 +18,7 @@ class PrimalSolutionComponent(Component): """ A component that predicts primal solutions. """ + def __init__(self, classifier=AdaptiveClassifier(), mode="exact", @@ -33,22 +34,22 @@ class PrimalSolutionComponent(Component): self.classifiers = {} self.classifier_prototype = classifier self.dynamic_thresholds = dynamic_thresholds - + def before_solve(self, solver, instance, model): solution = self.predict(instance) if self.mode == "heuristic": solver.internal_solver.fix(solution) else: solver.internal_solver.set_warm_start(solution) - + def after_solve(self, solver, instance, model, results): pass - + def fit(self, training_instances): logger.debug("Extracting features...") features = VariableFeaturesExtractor().extract(training_instances) solutions = SolutionExtractor().extract(training_instances) - + for category in tqdm(features.keys(), desc="Fit (Primal)"): x_train = features[category] y_train = solutions[category] @@ -69,11 +70,11 @@ class PrimalSolutionComponent(Component): self.thresholds[category, label] = self.min_threshold[label] logger.debug(" Setting threshold to %.4f" % self.min_threshold[label]) continue - + proba = pred.predict_proba(x_train) assert isinstance(proba, np.ndarray), \ "classifier should return numpy array" - assert proba.shape == (x_train.shape[0], 2),\ + assert proba.shape == (x_train.shape[0], 2), \ "classifier should return (%d,%d)-shaped array, not %s" % ( x_train.shape[0], 2, str(proba.shape)) @@ -89,10 +90,10 @@ class PrimalSolutionComponent(Component): if thresholds[k + 1] < self.min_threshold[label]: break k = k + 1 - logger.debug(" Setting threshold to %.4f (fpr=%.4f, tpr=%.4f)"% + logger.debug(" Setting threshold to %.4f (fpr=%.4f, tpr=%.4f)" % (thresholds[k], fpr[k], tpr[k])) self.thresholds[category, label] = thresholds[k] - + def predict(self, instance): x_test = VariableFeaturesExtractor().extract([instance]) solution = {} @@ -113,7 +114,8 @@ class PrimalSolutionComponent(Component): return solution def evaluate(self, instances): - ev = {} + ev = {"Fix zero": {}, + "Fix one": {}} for instance_idx in tqdm(range(len(instances))): instance = instances[instance_idx] solution_actual = instance.solution @@ -146,8 +148,6 @@ class PrimalSolutionComponent(Component): tn_one = len(pred_one_negative & vars_zero) fn_one = len(pred_one_negative & vars_one) - ev[instance_idx] = { - "Fix zero": classifier_evaluation_dict(tp_zero, tn_zero, fp_zero, fn_zero), - "Fix one": classifier_evaluation_dict(tp_one, tn_one, fp_one, fn_one), - } + ev["Fix zero"][instance_idx] = classifier_evaluation_dict(tp_zero, tn_zero, fp_zero, fn_zero) + ev["Fix one"][instance_idx] = classifier_evaluation_dict(tp_one, tn_one, fp_one, fn_one) return ev diff --git a/src/python/miplearn/components/tests/test_primal.py b/src/python/miplearn/components/tests/test_primal.py index 0116b65..c926494 100644 --- a/src/python/miplearn/components/tests/test_primal.py +++ b/src/python/miplearn/components/tests/test_primal.py @@ -50,7 +50,7 @@ def test_evaluate(): 2: 1, 3: 1}} ev = comp.evaluate(instances[:1]) - assert ev == {0: {'Fix one': {'Accuracy': 0.5, + assert ev == {'Fix one': {0: {'Accuracy': 0.5, 'Condition negative': 1, 'Condition negative (%)': 25.0, 'Condition positive': 3, @@ -69,8 +69,8 @@ def test_evaluate(): 'True negative': 1, 'True negative (%)': 25.0, 'True positive': 1, - 'True positive (%)': 25.0}, - 'Fix zero': {'Accuracy': 0.75, + 'True positive (%)': 25.0}}, + 'Fix zero': {0: {'Accuracy': 0.75, 'Condition negative': 3, 'Condition negative (%)': 75.0, 'Condition positive': 1,