|
|
|
@ -3,7 +3,7 @@
|
|
|
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
|
|
|
|
|
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
from overrides import overrides
|
|
|
|
@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
|
x: Dict[Hashable, List[List[float]]] = {}
|
|
|
|
|
y: Dict[Hashable, List[List[bool]]] = {}
|
|
|
|
|
cids: Dict[Hashable, List[str]] = {}
|
|
|
|
|
constr_categories_dict = instance.get_constraint_categories()
|
|
|
|
|
constr_features_dict = instance.get_constraint_features()
|
|
|
|
|
for cid in self.known_cids:
|
|
|
|
|
# Initialize categories
|
|
|
|
|
category = instance.get_constraint_category(cid)
|
|
|
|
|
if cid in constr_categories_dict:
|
|
|
|
|
category = constr_categories_dict[cid]
|
|
|
|
|
else:
|
|
|
|
|
category = cid
|
|
|
|
|
if category is None:
|
|
|
|
|
continue
|
|
|
|
|
if category not in x:
|
|
|
|
@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
|
assert sample.after_load is not None
|
|
|
|
|
assert sample.after_load.instance is not None
|
|
|
|
|
features.extend(sample.after_load.instance.to_list())
|
|
|
|
|
features.extend(instance.get_constraint_features(cid))
|
|
|
|
|
if cid in constr_features_dict:
|
|
|
|
|
features.extend(constr_features_dict[cid])
|
|
|
|
|
for ci in features:
|
|
|
|
|
assert isinstance(ci, float), (
|
|
|
|
|
f"Constraint features must be a list of floats. "
|
|
|
|
@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
|
|
|
|
|
tn: Dict[Hashable, int] = {}
|
|
|
|
|
fp: Dict[Hashable, int] = {}
|
|
|
|
|
fn: Dict[Hashable, int] = {}
|
|
|
|
|
constr_categories_dict = instance.get_constraint_categories()
|
|
|
|
|
for cid in self.known_cids:
|
|
|
|
|
category = instance.get_constraint_category(cid)
|
|
|
|
|
if category is None:
|
|
|
|
|
if cid not in constr_categories_dict:
|
|
|
|
|
continue
|
|
|
|
|
category = constr_categories_dict[cid]
|
|
|
|
|
if category not in tp.keys():
|
|
|
|
|
tp[category] = 0
|
|
|
|
|
tn[category] = 0
|
|
|
|
|