diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index bb0f8c0..283501e 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -97,8 +97,10 @@ class PrimalSolutionComponent(Component): def sample_predict(self, sample: Sample) -> Solution: var_names = sample.get_array("static_var_names") var_categories = sample.get_array("static_var_categories") + var_types = sample.get_array("static_var_types") assert var_names is not None assert var_categories is not None + assert var_types is not None # Compute y_pred x, _ = self.sample_xy(None, sample) @@ -122,6 +124,8 @@ class PrimalSolutionComponent(Component): solution: Solution = {v: None for v in var_names} category_offset: Dict[Category, int] = {cat: 0 for cat in x.keys()} for (i, var_name) in enumerate(var_names): + if var_types[i] != b"B": + continue category = var_categories[i] if category not in category_offset: continue diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index 5e54bee..753a228 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -293,6 +293,7 @@ class LearningSolver: # ------------------------------------------------------- if not discard_output: instance.flush() + instance.free() return stats