Minor fixes

This commit is contained in:
2021-09-04 06:31:37 -05:00
parent 65122c25b7
commit 9bd64c885a
2 changed files with 5 additions and 0 deletions

View File

@@ -97,8 +97,10 @@ class PrimalSolutionComponent(Component):
def sample_predict(self, sample: Sample) -> Solution:
var_names = sample.get_array("static_var_names")
var_categories = sample.get_array("static_var_categories")
var_types = sample.get_array("static_var_types")
assert var_names is not None
assert var_categories is not None
assert var_types is not None
# Compute y_pred
x, _ = self.sample_xy(None, sample)
@@ -122,6 +124,8 @@ class PrimalSolutionComponent(Component):
solution: Solution = {v: None for v in var_names}
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
for (i, var_name) in enumerate(var_names):
if var_types[i] != b"B":
continue
category = var_categories[i]
if category not in category_offset:
continue