diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 4a9f4b3..6c2cc84 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -119,6 +119,8 @@ class PrimalSolutionComponent(Component): vars_all, vars_one, vars_zero = set(), set(), set() pred_one_positive, pred_zero_positive = set(), set() for (varname, var_dict) in solution_actual.items(): + if varname not in solution_pred.keys(): + continue for (idx, value) in var_dict.items(): vars_all.add((varname, idx)) if value > 0.5: