Primal: Skip non-binary variables

master
Alinson S. Xavier 4 years ago
parent 5b3a56f053
commit 35272e08c6

@ -146,15 +146,21 @@ class PrimalSolutionComponent(Component):
mip_var_values = sample.get_array("mip_var_values") mip_var_values = sample.get_array("mip_var_values")
var_features = sample.get_array("lp_var_features") var_features = sample.get_array("lp_var_features")
var_names = sample.get_array("static_var_names") var_names = sample.get_array("static_var_names")
var_types = sample.get_array("static_var_types")
var_categories = sample.get_array("static_var_categories") var_categories = sample.get_array("static_var_categories")
if var_features is None: if var_features is None:
var_features = sample.get_array("static_var_features") var_features = sample.get_array("static_var_features")
assert instance_features is not None assert instance_features is not None
assert var_features is not None assert var_features is not None
assert var_names is not None assert var_names is not None
assert var_types is not None
assert var_categories is not None assert var_categories is not None
for (i, var_name) in enumerate(var_names): for (i, var_name) in enumerate(var_names):
# Skip non-binary variables
if var_types[i] != b"B":
continue
# Initialize categories # Initialize categories
category = var_categories[i] category = var_categories[i]
if len(category) == 0: if len(category) == 0:
@ -172,12 +178,6 @@ class PrimalSolutionComponent(Component):
if mip_var_values is not None: if mip_var_values is not None:
opt_value = mip_var_values[i] opt_value = mip_var_values[i]
assert opt_value is not None assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var_name} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to ''."
)
y[category].append([opt_value < 0.5, opt_value >= 0.5]) y[category].append([opt_value < 0.5, opt_value >= 0.5])
return x, y return x, y

@ -23,6 +23,7 @@ def sample() -> Sample:
sample = MemorySample( sample = MemorySample(
{ {
"static_var_names": np.array(["x[0]", "x[1]", "x[2]", "x[3]"], dtype="S"), "static_var_names": np.array(["x[0]", "x[1]", "x[2]", "x[3]"], dtype="S"),
"static_var_types": np.array(["B", "B", "B", "B"], dtype="S"),
"static_var_categories": np.array( "static_var_categories": np.array(
["default", "", "default", "default"], ["default", "", "default", "default"],
dtype="S", dtype="S",

Loading…
Cancel
Save