Primal: Use instance.features

master
Alinson S. Xavier 5 years ago
parent 12fca1f22b
commit 4f46866921

@ -117,15 +117,12 @@ class PrimalSolutionComponent(Component):
def predict(self, instance: Instance) -> Solution:
assert len(instance.training_data) > 0
sample = instance.training_data[-1]
assert "LP solution" in sample
lp_solution = sample["LP solution"]
assert lp_solution is not None
# Initialize empty solution
solution: Solution = {}
for (var_name, var_dict) in lp_solution.items():
for (var_name, var_dict) in instance.features["Variables"].items():
solution[var_name] = {}
for (idx, lp_value) in var_dict.items():
for idx in var_dict.keys():
solution[var_name][idx] = None
# Compute y_pred
@ -147,9 +144,9 @@ class PrimalSolutionComponent(Component):
# Convert y_pred into solution
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
for (var_name, var_dict) in lp_solution.items():
for (idx, lp_value) in var_dict.items():
category = instance.get_variable_category(var_name, idx)
for (var_name, var_dict) in instance.features["Variables"].items():
for (idx, var_features) in var_dict.items():
category = var_features["Category"]
offset = category_offset[category]
category_offset[category] += 1
if y_pred[category][offset, 0]:
@ -211,10 +208,8 @@ class PrimalSolutionComponent(Component):
instance: Any,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
x: Dict = {}
y: Dict = {}
if "Solution" not in sample:
return x, y
return {}, {}
assert sample["Solution"] is not None
return cast(
Tuple[Dict, Dict],
@ -222,7 +217,6 @@ class PrimalSolutionComponent(Component):
instance,
sample,
sample["Solution"],
extract_y=True,
),
)
@ -231,51 +225,43 @@ class PrimalSolutionComponent(Component):
instance: Any,
sample: TrainingSample,
) -> Dict:
return cast(
Dict,
PrimalSolutionComponent._extract(
instance,
sample,
instance.features["Variables"],
extract_y=False,
),
)
return cast(Dict, PrimalSolutionComponent._extract(instance, sample))
@staticmethod
def _extract(
instance: Any,
sample: TrainingSample,
variables: Dict,
extract_y: bool,
solution: Optional[Dict] = None,
) -> Union[Dict, Tuple[Dict, Dict]]:
x: Dict = {}
y: Dict = {}
for (var, var_dict) in variables.items():
for (idx, opt_value) in var_dict.items():
if extract_y:
opt_value = 0.0
for (var_name, var_dict) in instance.features["Variables"].items():
for (idx, var_features) in var_dict.items():
category = var_features["Category"]
if category is None:
continue
if solution is not None:
opt_value = solution[var_name][idx]
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var} has non-binary value {opt_value} in the "
f"Variable {var_name} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its "
"category to None."
)
category = instance.get_variable_category(var, idx)
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
features: Any = instance.get_variable_features(var, idx)
assert isinstance(features, list)
features = var_features["User features"]
if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var][idx]
lp_value = sample["LP solution"][var_name][idx]
if lp_value is not None:
features += [sample["LP solution"][var][idx]]
features += [sample["LP solution"][var_name][idx]]
x[category] += [features]
if extract_y:
if solution is not None:
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
if extract_y:
if solution is not None:
return x, y
else:
return x

@ -73,6 +73,15 @@ LearningSolveStats = TypedDict(
total=False,
)
VariableFeatures = TypedDict(
"VariableFeatures",
{
"Category": Optional[Hashable],
"User features": Optional[List[float]],
},
total=False,
)
ConstraintFeatures = TypedDict(
"ConstraintFeatures",
{
@ -88,7 +97,7 @@ ConstraintFeatures = TypedDict(
ModelFeatures = TypedDict(
"ModelFeatures",
{
"Variables": Solution,
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
"Constraints": Dict[str, ConstraintFeatures],
},
total=False,

@ -16,22 +16,27 @@ from miplearn.types import TrainingSample
def test_xy_sample_with_lp_solution() -> None:
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
side_effect=lambda var_name, index: {
0: "default",
1: None,
2: "default",
3: "default",
}[index]
)
instance.get_variable_features = Mock( # type: ignore
side_effect=lambda var, index: {
0: [0.0, 0.0],
1: [0.0, 1.0],
2: [1.0, 0.0],
3: [1.0, 1.0],
}[index]
)
instance.features = {
"Variables": {
"x": {
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": None,
},
2: {
"Category": "default",
"User features": [1.0, 0.0],
},
3: {
"Category": "default",
"User features": [1.0, 1.0],
},
}
}
}
sample: TrainingSample = {
"Solution": {
"x": {
@ -78,22 +83,27 @@ def test_xy_sample_with_lp_solution() -> None:
def test_xy_sample_without_lp_solution() -> None:
comp = PrimalSolutionComponent()
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
side_effect=lambda var_name, index: {
0: "default",
1: None,
2: "default",
3: "default",
}[index]
)
instance.get_variable_features = Mock( # type: ignore
side_effect=lambda var, index: {
0: [0.0, 0.0],
1: [0.0, 1.0],
2: [1.0, 0.0],
3: [1.0, 1.0],
}[index]
)
instance.features = {
"Variables": {
"x": {
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": None,
},
2: {
"Category": "default",
"User features": [1.0, 0.0],
},
3: {
"Category": "default",
"User features": [1.0, 1.0],
},
}
}
}
sample: TrainingSample = {
"Solution": {
"x": {
@ -143,22 +153,21 @@ def test_predict() -> None:
thr = Mock(spec=Threshold)
thr.predict = Mock(return_value=[0.75, 0.75])
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
return_value="default",
)
instance.get_variable_features = Mock( # type: ignore
side_effect=lambda var, index: {
0: [0.0, 0.0],
1: [0.0, 2.0],
2: [2.0, 0.0],
}[index]
)
instance.features = {
"Variables": {
"x": {
0: None,
1: None,
2: None,
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": "default",
"User features": [0.0, 2.0],
},
2: {
"Category": "default",
"User features": [2.0, 0.0],
},
}
}
}

Loading…
Cancel
Save