Request constraint features/categories in bulk

This commit is contained in:
2021-06-29 09:54:35 -05:00
parent 8118ab4110
commit a5092cc2b9
6 changed files with 51 additions and 38 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Dict, Hashable, List, Tuple, Optional, Any, FrozenSet, Set
from typing import Dict, Hashable, List, Tuple, Optional, Any, Set
import numpy as np
from overrides import overrides
@@ -50,9 +50,14 @@ class DynamicConstraintsComponent(Component):
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[bool]]] = {}
cids: Dict[Hashable, List[str]] = {}
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
for cid in self.known_cids:
# Initialize categories
category = instance.get_constraint_category(cid)
if cid in constr_categories_dict:
category = constr_categories_dict[cid]
else:
category = cid
if category is None:
continue
if category not in x:
@@ -65,7 +70,8 @@ class DynamicConstraintsComponent(Component):
assert sample.after_load is not None
assert sample.after_load.instance is not None
features.extend(sample.after_load.instance.to_list())
features.extend(instance.get_constraint_features(cid))
if cid in constr_features_dict:
features.extend(constr_features_dict[cid])
for ci in features:
assert isinstance(ci, float), (
f"Constraint features must be a list of floats. "
@@ -164,10 +170,11 @@ class DynamicConstraintsComponent(Component):
tn: Dict[Hashable, int] = {}
fp: Dict[Hashable, int] = {}
fn: Dict[Hashable, int] = {}
constr_categories_dict = instance.get_constraint_categories()
for cid in self.known_cids:
category = instance.get_constraint_category(cid)
if category is None:
if cid not in constr_categories_dict:
continue
category = constr_categories_dict[cid]
if category not in tp.keys():
tp[category] = 0
tn[category] = 0