mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Minor fixes
This commit is contained in:
@@ -97,8 +97,10 @@ class PrimalSolutionComponent(Component):
|
||||
def sample_predict(self, sample: Sample) -> Solution:
|
||||
var_names = sample.get_array("static_var_names")
|
||||
var_categories = sample.get_array("static_var_categories")
|
||||
var_types = sample.get_array("static_var_types")
|
||||
assert var_names is not None
|
||||
assert var_categories is not None
|
||||
assert var_types is not None
|
||||
|
||||
# Compute y_pred
|
||||
x, _ = self.sample_xy(None, sample)
|
||||
@@ -122,6 +124,8 @@ class PrimalSolutionComponent(Component):
|
||||
solution: Solution = {v: None for v in var_names}
|
||||
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
|
||||
for (i, var_name) in enumerate(var_names):
|
||||
if var_types[i] != b"B":
|
||||
continue
|
||||
category = var_categories[i]
|
||||
if category not in category_offset:
|
||||
continue
|
||||
|
||||
@@ -293,6 +293,7 @@ class LearningSolver:
|
||||
# -------------------------------------------------------
|
||||
if not discard_output:
|
||||
instance.flush()
|
||||
instance.free()
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
Reference in New Issue
Block a user