From 35272e08c6772680d952d349b52e8eb088bcde47 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Wed, 18 Aug 2021 10:34:56 -0500 Subject: [PATCH] Primal: Skip non-binary variables --- miplearn/components/primal.py | 12 ++++++------ tests/components/test_primal.py | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 00a0140..bb0f8c0 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -146,15 +146,21 @@ class PrimalSolutionComponent(Component): mip_var_values = sample.get_array("mip_var_values") var_features = sample.get_array("lp_var_features") var_names = sample.get_array("static_var_names") + var_types = sample.get_array("static_var_types") var_categories = sample.get_array("static_var_categories") if var_features is None: var_features = sample.get_array("static_var_features") assert instance_features is not None assert var_features is not None assert var_names is not None + assert var_types is not None assert var_categories is not None for (i, var_name) in enumerate(var_names): + # Skip non-binary variables + if var_types[i] != b"B": + continue + # Initialize categories category = var_categories[i] if len(category) == 0: @@ -172,12 +178,6 @@ class PrimalSolutionComponent(Component): if mip_var_values is not None: opt_value = mip_var_values[i] assert opt_value is not None - assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( - f"Variable {var_name} has non-binary value {opt_value} in the " - "optimal solution. Predicting values of non-binary " - "variables is not currently supported. Please set its " - "category to ''." - ) y[category].append([opt_value < 0.5, opt_value >= 0.5]) return x, y diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index a77cad0..6acebee 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -23,6 +23,7 @@ def sample() -> Sample: sample = MemorySample( { "static_var_names": np.array(["x[0]", "x[1]", "x[2]", "x[3]"], dtype="S"), + "static_var_types": np.array(["B", "B", "B", "B"], dtype="S"), "static_var_categories": np.array( ["default", "", "default", "default"], dtype="S",