mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 02:18:51 -06:00
Refer to variables by varname instead of (vname, index)
This commit is contained in:
@@ -7,7 +7,6 @@ from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Hashable,
|
||||
Optional,
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
Tuple,
|
||||
@@ -23,8 +22,9 @@ from miplearn.components.component import Component
|
||||
from miplearn.features import TrainingSample, Features
|
||||
from miplearn.instance.base import Instance
|
||||
from miplearn.types import (
|
||||
Solution,
|
||||
LearningSolveStats,
|
||||
Category,
|
||||
Solution,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -84,15 +84,14 @@ class PrimalSolutionComponent(Component):
|
||||
stats["Primal: Free"] = 0
|
||||
stats["Primal: Zero"] = 0
|
||||
stats["Primal: One"] = 0
|
||||
for (var, var_dict) in solution.items():
|
||||
for (idx, value) in var_dict.items():
|
||||
if value is None:
|
||||
stats["Primal: Free"] += 1
|
||||
for (var_name, value) in solution.items():
|
||||
if value is None:
|
||||
stats["Primal: Free"] += 1
|
||||
else:
|
||||
if value < 0.5:
|
||||
stats["Primal: Zero"] += 1
|
||||
else:
|
||||
if value < 0.5:
|
||||
stats["Primal: Zero"] += 1
|
||||
else:
|
||||
stats["Primal: One"] += 1
|
||||
stats["Primal: One"] += 1
|
||||
logger.info(
|
||||
f"Predicted: free: {stats['Primal: Free']}, "
|
||||
f"zero: {stats['Primal: Zero']}, "
|
||||
@@ -106,13 +105,6 @@ class PrimalSolutionComponent(Component):
|
||||
) -> Solution:
|
||||
assert instance.features.variables is not None
|
||||
|
||||
# Initialize empty solution
|
||||
solution: Solution = {}
|
||||
for (var_name, var_dict) in instance.features.variables.items():
|
||||
solution[var_name] = {}
|
||||
for idx in var_dict.keys():
|
||||
solution[var_name][idx] = None
|
||||
|
||||
# Compute y_pred
|
||||
x, _ = self.sample_xy(instance, sample)
|
||||
y_pred = {}
|
||||
@@ -132,56 +124,52 @@ class PrimalSolutionComponent(Component):
|
||||
).T
|
||||
|
||||
# Convert y_pred into solution
|
||||
solution: Solution = {v: None for v in instance.features.variables.keys()}
|
||||
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
|
||||
for (var_name, var_dict) in instance.features.variables.items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features.category
|
||||
offset = category_offset[category]
|
||||
category_offset[category] += 1
|
||||
if y_pred[category][offset, 0]:
|
||||
solution[var_name][idx] = 0.0
|
||||
if y_pred[category][offset, 1]:
|
||||
solution[var_name][idx] = 1.0
|
||||
for (var_name, var_features) in instance.features.variables.items():
|
||||
category = var_features.category
|
||||
offset = category_offset[category]
|
||||
category_offset[category] += 1
|
||||
if y_pred[category][offset, 0]:
|
||||
solution[var_name] = 0.0
|
||||
if y_pred[category][offset, 1]:
|
||||
solution[var_name] = 1.0
|
||||
|
||||
return solution
|
||||
|
||||
@staticmethod
|
||||
def sample_xy(
|
||||
self,
|
||||
instance: Instance,
|
||||
sample: TrainingSample,
|
||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||
) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]:
|
||||
assert instance.features.variables is not None
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
solution: Optional[Solution] = None
|
||||
if sample.solution is not None:
|
||||
solution = sample.solution
|
||||
for (var_name, var_dict) in instance.features.variables.items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features.category
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
f: List[float] = []
|
||||
assert var_features.user_features is not None
|
||||
f += var_features.user_features
|
||||
if sample.lp_solution is not None:
|
||||
lp_value = sample.lp_solution[var_name][idx]
|
||||
if lp_value is not None:
|
||||
f += [lp_value]
|
||||
x[category] += [f]
|
||||
if solution is not None:
|
||||
opt_value = solution[var_name][idx]
|
||||
assert opt_value is not None
|
||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
||||
f"Variable {var_name} has non-binary value {opt_value} in the "
|
||||
"optimal solution. Predicting values of non-binary "
|
||||
"variables is not currently supported. Please set its "
|
||||
"category to None."
|
||||
)
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
for (var_name, var_features) in instance.features.variables.items():
|
||||
category = var_features.category
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
f: List[float] = []
|
||||
assert var_features.user_features is not None
|
||||
f += var_features.user_features
|
||||
if sample.lp_solution is not None:
|
||||
lp_value = sample.lp_solution[var_name]
|
||||
if lp_value is not None:
|
||||
f += [lp_value]
|
||||
x[category] += [f]
|
||||
if sample.solution is not None:
|
||||
opt_value = sample.solution[var_name]
|
||||
assert opt_value is not None
|
||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
||||
f"Variable {var_name} has non-binary value {opt_value} in the "
|
||||
"optimal solution. Predicting values of non-binary "
|
||||
"variables is not currently supported. Please set its "
|
||||
"category to None."
|
||||
)
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
return x, y
|
||||
|
||||
def sample_evaluate(
|
||||
@@ -194,22 +182,19 @@ class PrimalSolutionComponent(Component):
|
||||
solution_pred = self.sample_predict(instance, sample)
|
||||
vars_all, vars_one, vars_zero = set(), set(), set()
|
||||
pred_one_positive, pred_zero_positive = set(), set()
|
||||
for (varname, var_dict) in solution_actual.items():
|
||||
if varname not in solution_pred.keys():
|
||||
continue
|
||||
for (idx, value_actual) in var_dict.items():
|
||||
assert value_actual is not None
|
||||
vars_all.add((varname, idx))
|
||||
if value_actual > 0.5:
|
||||
vars_one.add((varname, idx))
|
||||
for (var_name, value_actual) in solution_actual.items():
|
||||
assert value_actual is not None
|
||||
vars_all.add(var_name)
|
||||
if value_actual > 0.5:
|
||||
vars_one.add(var_name)
|
||||
else:
|
||||
vars_zero.add(var_name)
|
||||
value_pred = solution_pred[var_name]
|
||||
if value_pred is not None:
|
||||
if value_pred > 0.5:
|
||||
pred_one_positive.add(var_name)
|
||||
else:
|
||||
vars_zero.add((varname, idx))
|
||||
value_pred = solution_pred[varname][idx]
|
||||
if value_pred is not None:
|
||||
if value_pred > 0.5:
|
||||
pred_one_positive.add((varname, idx))
|
||||
else:
|
||||
pred_zero_positive.add((varname, idx))
|
||||
pred_zero_positive.add(var_name)
|
||||
pred_one_negative = vars_all - pred_one_positive
|
||||
pred_zero_negative = vars_all - pred_zero_positive
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user