Minor fixes

master
Alinson S. Xavier 4 years ago
parent 65122c25b7
commit 9bd64c885a

@ -97,8 +97,10 @@ class PrimalSolutionComponent(Component):
def sample_predict(self, sample: Sample) -> Solution: def sample_predict(self, sample: Sample) -> Solution:
var_names = sample.get_array("static_var_names") var_names = sample.get_array("static_var_names")
var_categories = sample.get_array("static_var_categories") var_categories = sample.get_array("static_var_categories")
var_types = sample.get_array("static_var_types")
assert var_names is not None assert var_names is not None
assert var_categories is not None assert var_categories is not None
assert var_types is not None
# Compute y_pred # Compute y_pred
x, _ = self.sample_xy(None, sample) x, _ = self.sample_xy(None, sample)
@ -122,6 +124,8 @@ class PrimalSolutionComponent(Component):
solution: Solution = {v: None for v in var_names} solution: Solution = {v: None for v in var_names}
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()} category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
for (i, var_name) in enumerate(var_names): for (i, var_name) in enumerate(var_names):
if var_types[i] != b"B":
continue
category = var_categories[i] category = var_categories[i]
if category not in category_offset: if category not in category_offset:
continue continue

@ -293,6 +293,7 @@ class LearningSolver:
# ------------------------------------------------------- # -------------------------------------------------------
if not discard_output: if not discard_output:
instance.flush() instance.flush()
instance.free()
return stats return stats

Loading…
Cancel
Save