Use np.ndarray for constraint names

This commit is contained in:
2021-08-09 05:41:01 -05:00
parent 45667ac2e4
commit 9ddda7e1e2
7 changed files with 33 additions and 27 deletions

View File

@@ -200,11 +200,11 @@ class FeaturesExtractor:
) -> None:
has_static_lazy = instance.has_static_lazy_constraints()
user_features: List[Optional[List[float]]] = []
categories: List[Optional[str]] = []
categories: List[Optional[bytes]] = []
lazy: List[bool] = []
constr_categories_dict = instance.get_constraint_categories()
constr_features_dict = instance.get_constraint_features()
constr_names = sample.get_vector("static_constr_names")
constr_names = sample.get_array("static_constr_names")
assert constr_names is not None
for (cidx, cname) in enumerate(constr_names):
@@ -215,8 +215,8 @@ class FeaturesExtractor:
user_features.append(None)
categories.append(None)
continue
assert isinstance(category, str), (
f"Constraint category must be a string. "
assert isinstance(category, bytes), (
f"Constraint category must be bytes. "
f"Found {type(category).__name__} instead for cname={cname}.",
)
categories.append(category)
@@ -242,7 +242,7 @@ class FeaturesExtractor:
lazy.append(False)
sample.put_vector_list("static_constr_features", user_features)
sample.put_vector("static_constr_lazy", lazy)
sample.put_vector("static_constr_categories", categories)
sample.put_array("static_constr_categories", np.array(categories, dtype="S"))
def _extract_user_features_instance(
self,