|
|
@ -97,8 +97,10 @@ class PrimalSolutionComponent(Component):
|
|
|
|
def sample_predict(self, sample: Sample) -> Solution:
|
|
|
|
def sample_predict(self, sample: Sample) -> Solution:
|
|
|
|
var_names = sample.get_array("static_var_names")
|
|
|
|
var_names = sample.get_array("static_var_names")
|
|
|
|
var_categories = sample.get_array("static_var_categories")
|
|
|
|
var_categories = sample.get_array("static_var_categories")
|
|
|
|
|
|
|
|
var_types = sample.get_array("static_var_types")
|
|
|
|
assert var_names is not None
|
|
|
|
assert var_names is not None
|
|
|
|
assert var_categories is not None
|
|
|
|
assert var_categories is not None
|
|
|
|
|
|
|
|
assert var_types is not None
|
|
|
|
|
|
|
|
|
|
|
|
# Compute y_pred
|
|
|
|
# Compute y_pred
|
|
|
|
x, _ = self.sample_xy(None, sample)
|
|
|
|
x, _ = self.sample_xy(None, sample)
|
|
|
@ -122,6 +124,8 @@ class PrimalSolutionComponent(Component):
|
|
|
|
solution: Solution = {v: None for v in var_names}
|
|
|
|
solution: Solution = {v: None for v in var_names}
|
|
|
|
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
|
|
|
|
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
|
|
|
|
for (i, var_name) in enumerate(var_names):
|
|
|
|
for (i, var_name) in enumerate(var_names):
|
|
|
|
|
|
|
|
if var_types[i] != b"B":
|
|
|
|
|
|
|
|
continue
|
|
|
|
category = var_categories[i]
|
|
|
|
category = var_categories[i]
|
|
|
|
if category not in category_offset:
|
|
|
|
if category not in category_offset:
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|