diff --git a/src/python/miplearn/problems/stab.py b/src/python/miplearn/problems/stab.py index 2854792..051e7e5 100644 --- a/src/python/miplearn/problems/stab.py +++ b/src/python/miplearn/problems/stab.py @@ -112,10 +112,21 @@ class MaxWeightStableSetInstance(Instance): return model def get_instance_features(self): - return np.array(self.weights) - - def get_variable_features(self, var, index): return np.ones(0) - + + def get_variable_features(self, var, index): + neighbor_weights = [0] * 15 + neighbor_degrees = [100] * 15 + for n in self.graph.neighbors(index): + neighbor_weights += [self.weights[n] / self.weights[index]] + neighbor_degrees += [self.graph.degree(n) / self.graph.degree(index)] + neighbor_weights.sort(reverse=True) + neighbor_degrees.sort() + features = [] + features += neighbor_weights[:5] + features += neighbor_degrees[:5] + features += [self.graph.degree(index)] + return np.array(features) + def get_variable_category(self, var, index): - return index + return "default"