diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 00a0140..bb0f8c0 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -146,15 +146,21 @@ class PrimalSolutionComponent(Component): mip_var_values = sample.get_array("mip_var_values") var_features = sample.get_array("lp_var_features") var_names = sample.get_array("static_var_names") + var_types = sample.get_array("static_var_types") var_categories = sample.get_array("static_var_categories") if var_features is None: var_features = sample.get_array("static_var_features") assert instance_features is not None assert var_features is not None assert var_names is not None + assert var_types is not None assert var_categories is not None for (i, var_name) in enumerate(var_names): + # Skip non-binary variables + if var_types[i] != b"B": + continue + # Initialize categories category = var_categories[i] if len(category) == 0: @@ -172,12 +178,6 @@ class PrimalSolutionComponent(Component): if mip_var_values is not None: opt_value = mip_var_values[i] assert opt_value is not None - assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( - f"Variable {var_name} has non-binary value {opt_value} in the " - "optimal solution. Predicting values of non-binary " - "variables is not currently supported. Please set its " - "category to ''." - ) y[category].append([opt_value < 0.5, opt_value >= 0.5]) return x, y diff --git a/setup.py b/setup.py index c82b8ac..a97a332 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ with open("README.md", "r") as fh: setup( name="miplearn", - version="0.2.0.dev11", + version="0.2.0.dev12", author="Alinson S. Xavier", author_email="axavier@anl.gov", description="Extensible framework for Learning-Enhanced Mixed-Integer Optimization", diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index a77cad0..6acebee 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -23,6 +23,7 @@ def sample() -> Sample: sample = MemorySample( { "static_var_names": np.array(["x[0]", "x[1]", "x[2]", "x[3]"], dtype="S"), + "static_var_types": np.array(["B", "B", "B", "B"], dtype="S"), "static_var_categories": np.array( ["default", "", "default", "default"], dtype="S",