mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Primal: Use instance.features
This commit is contained in:
@@ -117,15 +117,12 @@ class PrimalSolutionComponent(Component):
|
||||
def predict(self, instance: Instance) -> Solution:
|
||||
assert len(instance.training_data) > 0
|
||||
sample = instance.training_data[-1]
|
||||
assert "LP solution" in sample
|
||||
lp_solution = sample["LP solution"]
|
||||
assert lp_solution is not None
|
||||
|
||||
# Initialize empty solution
|
||||
solution: Solution = {}
|
||||
for (var_name, var_dict) in lp_solution.items():
|
||||
for (var_name, var_dict) in instance.features["Variables"].items():
|
||||
solution[var_name] = {}
|
||||
for (idx, lp_value) in var_dict.items():
|
||||
for idx in var_dict.keys():
|
||||
solution[var_name][idx] = None
|
||||
|
||||
# Compute y_pred
|
||||
@@ -147,9 +144,9 @@ class PrimalSolutionComponent(Component):
|
||||
|
||||
# Convert y_pred into solution
|
||||
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
|
||||
for (var_name, var_dict) in lp_solution.items():
|
||||
for (idx, lp_value) in var_dict.items():
|
||||
category = instance.get_variable_category(var_name, idx)
|
||||
for (var_name, var_dict) in instance.features["Variables"].items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features["Category"]
|
||||
offset = category_offset[category]
|
||||
category_offset[category] += 1
|
||||
if y_pred[category][offset, 0]:
|
||||
@@ -211,10 +208,8 @@ class PrimalSolutionComponent(Component):
|
||||
instance: Any,
|
||||
sample: TrainingSample,
|
||||
) -> Tuple[Dict, Dict]:
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
if "Solution" not in sample:
|
||||
return x, y
|
||||
return {}, {}
|
||||
assert sample["Solution"] is not None
|
||||
return cast(
|
||||
Tuple[Dict, Dict],
|
||||
@@ -222,7 +217,6 @@ class PrimalSolutionComponent(Component):
|
||||
instance,
|
||||
sample,
|
||||
sample["Solution"],
|
||||
extract_y=True,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -231,51 +225,43 @@ class PrimalSolutionComponent(Component):
|
||||
instance: Any,
|
||||
sample: TrainingSample,
|
||||
) -> Dict:
|
||||
return cast(
|
||||
Dict,
|
||||
PrimalSolutionComponent._extract(
|
||||
instance,
|
||||
sample,
|
||||
instance.features["Variables"],
|
||||
extract_y=False,
|
||||
),
|
||||
)
|
||||
return cast(Dict, PrimalSolutionComponent._extract(instance, sample))
|
||||
|
||||
@staticmethod
|
||||
def _extract(
|
||||
instance: Any,
|
||||
sample: TrainingSample,
|
||||
variables: Dict,
|
||||
extract_y: bool,
|
||||
solution: Optional[Dict] = None,
|
||||
) -> Union[Dict, Tuple[Dict, Dict]]:
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
for (var, var_dict) in variables.items():
|
||||
for (idx, opt_value) in var_dict.items():
|
||||
if extract_y:
|
||||
opt_value = 0.0
|
||||
for (var_name, var_dict) in instance.features["Variables"].items():
|
||||
for (idx, var_features) in var_dict.items():
|
||||
category = var_features["Category"]
|
||||
if category is None:
|
||||
continue
|
||||
if solution is not None:
|
||||
opt_value = solution[var_name][idx]
|
||||
assert opt_value is not None
|
||||
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
|
||||
f"Variable {var} has non-binary value {opt_value} in the "
|
||||
f"Variable {var_name} has non-binary value {opt_value} in the "
|
||||
"optimal solution. Predicting values of non-binary "
|
||||
"variables is not currently supported. Please set its "
|
||||
"category to None."
|
||||
)
|
||||
category = instance.get_variable_category(var, idx)
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x.keys():
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
features: Any = instance.get_variable_features(var, idx)
|
||||
assert isinstance(features, list)
|
||||
features = var_features["User features"]
|
||||
if "LP solution" in sample and sample["LP solution"] is not None:
|
||||
lp_value = sample["LP solution"][var][idx]
|
||||
lp_value = sample["LP solution"][var_name][idx]
|
||||
if lp_value is not None:
|
||||
features += [sample["LP solution"][var][idx]]
|
||||
features += [sample["LP solution"][var_name][idx]]
|
||||
x[category] += [features]
|
||||
if extract_y:
|
||||
if solution is not None:
|
||||
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
|
||||
if extract_y:
|
||||
if solution is not None:
|
||||
return x, y
|
||||
else:
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user