@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict , Hashable, List, Tuple , Optional , Any , Set
from typing import Dict , List, Tuple , Optional , Any , Set
import numpy as np
from overrides import overrides
@ -32,8 +32,8 @@ class DynamicConstraintsComponent(Component):
assert isinstance ( classifier , Classifier )
self . threshold_prototype : Threshold = threshold
self . classifier_prototype : Classifier = classifier
self . classifiers : Dict [ Hashable , Classifier ] = { }
self . thresholds : Dict [ Hashable , Threshold ] = { }
self . classifiers : Dict [ str , Classifier ] = { }
self . thresholds : Dict [ str , Threshold ] = { }
self . known_cids : List [ str ] = [ ]
self . attr = attr
@ -42,14 +42,14 @@ class DynamicConstraintsComponent(Component):
instance : Optional [ Instance ] ,
sample : Sample ,
) - > Tuple [
Dict [ Hashable , List [ List [ float ] ] ] ,
Dict [ Hashable , List [ List [ bool ] ] ] ,
Dict [ Hashable , List [ str ] ] ,
Dict [ str , List [ List [ float ] ] ] ,
Dict [ str , List [ List [ bool ] ] ] ,
Dict [ str , List [ str ] ] ,
] :
assert instance is not None
x : Dict [ Hashable , List [ List [ float ] ] ] = { }
y : Dict [ Hashable , List [ List [ bool ] ] ] = { }
cids : Dict [ Hashable , List [ str ] ] = { }
x : Dict [ str , List [ List [ float ] ] ] = { }
y : Dict [ str , List [ List [ bool ] ] ] = { }
cids : Dict [ str , List [ str ] ] = { }
constr_categories_dict = instance . get_constraint_categories ( )
constr_features_dict = instance . get_constraint_features ( )
instance_features = sample . get ( " instance_features_user " )
@ -111,8 +111,8 @@ class DynamicConstraintsComponent(Component):
self ,
instance : Instance ,
sample : Sample ,
) - > List [ Hashable ] :
pred : List [ Hashable ] = [ ]
) - > List [ str ] :
pred : List [ str ] = [ ]
if len ( self . known_cids ) == 0 :
logger . info ( " Classifiers not fitted. Skipping. " )
return pred
@ -137,8 +137,8 @@ class DynamicConstraintsComponent(Component):
@overrides
def fit_xy (
self ,
x : Dict [ Hashable , np . ndarray ] ,
y : Dict [ Hashable , np . ndarray ] ,
x : Dict [ str , np . ndarray ] ,
y : Dict [ str , np . ndarray ] ,
) - > None :
for category in x . keys ( ) :
self . classifiers [ category ] = self . classifier_prototype . clone ( )
@ -153,14 +153,14 @@ class DynamicConstraintsComponent(Component):
self ,
instance : Instance ,
sample : Sample ,
) - > Dict [ Hashable , Dict [ str , float ] ] :
) - > Dict [ str , Dict [ str , float ] ] :
actual = sample . get ( self . attr )
assert actual is not None
pred = set ( self . sample_predict ( instance , sample ) )
tp : Dict [ Hashable , int ] = { }
tn : Dict [ Hashable , int ] = { }
fp : Dict [ Hashable , int ] = { }
fn : Dict [ Hashable , int ] = { }
tp : Dict [ str , int ] = { }
tn : Dict [ str , int ] = { }
fp : Dict [ str , int ] = { }
fn : Dict [ str , int ] = { }
constr_categories_dict = instance . get_constraint_categories ( )
for cid in self . known_cids :
if cid not in constr_categories_dict :