|
|
@ -48,6 +48,8 @@ class UserCutsComponent(Component):
|
|
|
|
self.classifiers = {}
|
|
|
|
self.classifiers = {}
|
|
|
|
violation_to_instance_idx = {}
|
|
|
|
violation_to_instance_idx = {}
|
|
|
|
for (idx, instance) in enumerate(training_instances):
|
|
|
|
for (idx, instance) in enumerate(training_instances):
|
|
|
|
|
|
|
|
if not hasattr(instance, "found_violated_user_cuts"):
|
|
|
|
|
|
|
|
continue
|
|
|
|
for v in instance.found_violated_user_cuts:
|
|
|
|
for v in instance.found_violated_user_cuts:
|
|
|
|
if v not in self.classifiers:
|
|
|
|
if v not in self.classifiers:
|
|
|
|
self.classifiers[v] = deepcopy(self.classifier_prototype)
|
|
|
|
self.classifiers[v] = deepcopy(self.classifier_prototype)
|
|
|
|