diff --git a/src/python/miplearn/components/primal.py b/src/python/miplearn/components/primal.py index d430768..a0d993e 100644 --- a/src/python/miplearn/components/primal.py +++ b/src/python/miplearn/components/primal.py @@ -90,7 +90,7 @@ class PrimalSolutionComponent(Component): if (category, label) not in self.classifiers.keys(): continue clf = self.classifiers[category, label] - if isinstance(clf, float): + if isinstance(clf, float) or isinstance(clf, int): ws = np.array([[1 - clf, clf] for _ in range(n)]) else: ws = clf.predict_proba(x_test[category])