Remove {get,put}_set and deprecated functions

This commit is contained in:
2021-08-10 17:27:06 -05:00
parent ed58242b5c
commit 9cfb31bacb
11 changed files with 56 additions and 124 deletions

View File

@@ -56,6 +56,11 @@ class DynamicConstraintsComponent(Component):
cids: Dict[ConstraintCategory, List[ConstraintName]] = {}
known_cids = np.array(self.known_cids, dtype="S")
enforced_cids = None
enforced_cids_np = sample.get_array(self.attr)
if enforced_cids_np is not None:
enforced_cids = list(enforced_cids_np)
# Get user-provided constraint features
(
constr_features,
@@ -72,13 +77,11 @@ class DynamicConstraintsComponent(Component):
constr_features,
]
)
assert len(known_cids) == constr_features.shape[0]
categories = np.unique(constr_categories)
for c in categories:
x[c] = constr_features[constr_categories == c].tolist()
cids[c] = known_cids[constr_categories == c].tolist()
enforced_cids = np.array(list(sample.get_set(self.attr)), dtype="S")
if enforced_cids is not None:
tmp = np.isin(cids[c], enforced_cids).reshape(-1, 1)
y[c] = np.hstack([~tmp, tmp]).tolist() # type: ignore
@@ -99,7 +102,7 @@ class DynamicConstraintsComponent(Component):
assert pre is not None
known_cids: Set = set()
for cids in pre:
known_cids |= cids
known_cids |= set(list(cids))
self.known_cids.clear()
self.known_cids.extend(sorted(known_cids))
@@ -128,7 +131,7 @@ class DynamicConstraintsComponent(Component):
@overrides
def pre_sample_xy(self, instance: Instance, sample: Sample) -> Any:
return sample.get_set(self.attr)
return sample.get_array(self.attr)
@overrides
def fit_xy(
@@ -150,7 +153,7 @@ class DynamicConstraintsComponent(Component):
instance: Instance,
sample: Sample,
) -> Dict[str, float]:
actual = sample.get_set(self.attr)
actual = sample.get_array(self.attr)
assert actual is not None
pred = set(self.sample_predict(instance, sample))
tp, tn, fp, fn = 0, 0, 0, 0