mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Minor fixes
This commit is contained in:
@@ -97,8 +97,10 @@ class PrimalSolutionComponent(Component):
|
|||||||
def sample_predict(self, sample: Sample) -> Solution:
|
def sample_predict(self, sample: Sample) -> Solution:
|
||||||
var_names = sample.get_array("static_var_names")
|
var_names = sample.get_array("static_var_names")
|
||||||
var_categories = sample.get_array("static_var_categories")
|
var_categories = sample.get_array("static_var_categories")
|
||||||
|
var_types = sample.get_array("static_var_types")
|
||||||
assert var_names is not None
|
assert var_names is not None
|
||||||
assert var_categories is not None
|
assert var_categories is not None
|
||||||
|
assert var_types is not None
|
||||||
|
|
||||||
# Compute y_pred
|
# Compute y_pred
|
||||||
x, _ = self.sample_xy(None, sample)
|
x, _ = self.sample_xy(None, sample)
|
||||||
@@ -122,6 +124,8 @@ class PrimalSolutionComponent(Component):
|
|||||||
solution: Solution = {v: None for v in var_names}
|
solution: Solution = {v: None for v in var_names}
|
||||||
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
|
category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()}
|
||||||
for (i, var_name) in enumerate(var_names):
|
for (i, var_name) in enumerate(var_names):
|
||||||
|
if var_types[i] != b"B":
|
||||||
|
continue
|
||||||
category = var_categories[i]
|
category = var_categories[i]
|
||||||
if category not in category_offset:
|
if category not in category_offset:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -293,6 +293,7 @@ class LearningSolver:
|
|||||||
# -------------------------------------------------------
|
# -------------------------------------------------------
|
||||||
if not discard_output:
|
if not discard_output:
|
||||||
instance.flush()
|
instance.flush()
|
||||||
|
instance.free()
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user