Primal: Use instance.features

master
Alinson S. Xavier 5 years ago
parent 12fca1f22b
commit 4f46866921

@ -117,15 +117,12 @@ class PrimalSolutionComponent(Component):
def predict(self, instance: Instance) -> Solution: def predict(self, instance: Instance) -> Solution:
assert len(instance.training_data) > 0 assert len(instance.training_data) > 0
sample = instance.training_data[-1] sample = instance.training_data[-1]
assert "LP solution" in sample
lp_solution = sample["LP solution"]
assert lp_solution is not None
# Initialize empty solution # Initialize empty solution
solution: Solution = {} solution: Solution = {}
for (var_name, var_dict) in lp_solution.items(): for (var_name, var_dict) in instance.features["Variables"].items():
solution[var_name] = {} solution[var_name] = {}
for (idx, lp_value) in var_dict.items(): for idx in var_dict.keys():
solution[var_name][idx] = None solution[var_name][idx] = None
# Compute y_pred # Compute y_pred
@ -147,9 +144,9 @@ class PrimalSolutionComponent(Component):
# Convert y_pred into solution # Convert y_pred into solution
category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()} category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()}
for (var_name, var_dict) in lp_solution.items(): for (var_name, var_dict) in instance.features["Variables"].items():
for (idx, lp_value) in var_dict.items(): for (idx, var_features) in var_dict.items():
category = instance.get_variable_category(var_name, idx) category = var_features["Category"]
offset = category_offset[category] offset = category_offset[category]
category_offset[category] += 1 category_offset[category] += 1
if y_pred[category][offset, 0]: if y_pred[category][offset, 0]:
@ -211,10 +208,8 @@ class PrimalSolutionComponent(Component):
instance: Any, instance: Any,
sample: TrainingSample, sample: TrainingSample,
) -> Tuple[Dict, Dict]: ) -> Tuple[Dict, Dict]:
x: Dict = {}
y: Dict = {}
if "Solution" not in sample: if "Solution" not in sample:
return x, y return {}, {}
assert sample["Solution"] is not None assert sample["Solution"] is not None
return cast( return cast(
Tuple[Dict, Dict], Tuple[Dict, Dict],
@ -222,7 +217,6 @@ class PrimalSolutionComponent(Component):
instance, instance,
sample, sample,
sample["Solution"], sample["Solution"],
extract_y=True,
), ),
) )
@ -231,51 +225,43 @@ class PrimalSolutionComponent(Component):
instance: Any, instance: Any,
sample: TrainingSample, sample: TrainingSample,
) -> Dict: ) -> Dict:
return cast( return cast(Dict, PrimalSolutionComponent._extract(instance, sample))
Dict,
PrimalSolutionComponent._extract(
instance,
sample,
instance.features["Variables"],
extract_y=False,
),
)
@staticmethod @staticmethod
def _extract( def _extract(
instance: Any, instance: Any,
sample: TrainingSample, sample: TrainingSample,
variables: Dict, solution: Optional[Dict] = None,
extract_y: bool,
) -> Union[Dict, Tuple[Dict, Dict]]: ) -> Union[Dict, Tuple[Dict, Dict]]:
x: Dict = {} x: Dict = {}
y: Dict = {} y: Dict = {}
for (var, var_dict) in variables.items(): opt_value = 0.0
for (idx, opt_value) in var_dict.items(): for (var_name, var_dict) in instance.features["Variables"].items():
if extract_y: for (idx, var_features) in var_dict.items():
category = var_features["Category"]
if category is None:
continue
if solution is not None:
opt_value = solution[var_name][idx]
assert opt_value is not None assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var} has non-binary value {opt_value} in the " f"Variable {var_name} has non-binary value {opt_value} in the "
"optimal solution. Predicting values of non-binary " "optimal solution. Predicting values of non-binary "
"variables is not currently supported. Please set its " "variables is not currently supported. Please set its "
"category to None." "category to None."
) )
category = instance.get_variable_category(var, idx)
if category is None:
continue
if category not in x.keys(): if category not in x.keys():
x[category] = [] x[category] = []
y[category] = [] y[category] = []
features: Any = instance.get_variable_features(var, idx) features = var_features["User features"]
assert isinstance(features, list)
if "LP solution" in sample and sample["LP solution"] is not None: if "LP solution" in sample and sample["LP solution"] is not None:
lp_value = sample["LP solution"][var][idx] lp_value = sample["LP solution"][var_name][idx]
if lp_value is not None: if lp_value is not None:
features += [sample["LP solution"][var][idx]] features += [sample["LP solution"][var_name][idx]]
x[category] += [features] x[category] += [features]
if extract_y: if solution is not None:
y[category] += [[opt_value < 0.5, opt_value >= 0.5]] y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
if extract_y: if solution is not None:
return x, y return x, y
else: else:
return x return x

@ -73,6 +73,15 @@ LearningSolveStats = TypedDict(
total=False, total=False,
) )
VariableFeatures = TypedDict(
"VariableFeatures",
{
"Category": Optional[Hashable],
"User features": Optional[List[float]],
},
total=False,
)
ConstraintFeatures = TypedDict( ConstraintFeatures = TypedDict(
"ConstraintFeatures", "ConstraintFeatures",
{ {
@ -88,7 +97,7 @@ ConstraintFeatures = TypedDict(
ModelFeatures = TypedDict( ModelFeatures = TypedDict(
"ModelFeatures", "ModelFeatures",
{ {
"Variables": Solution, "Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
"Constraints": Dict[str, ConstraintFeatures], "Constraints": Dict[str, ConstraintFeatures],
}, },
total=False, total=False,

@ -16,22 +16,27 @@ from miplearn.types import TrainingSample
def test_xy_sample_with_lp_solution() -> None: def test_xy_sample_with_lp_solution() -> None:
instance = cast(Instance, Mock(spec=Instance)) instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore instance.features = {
side_effect=lambda var_name, index: { "Variables": {
0: "default", "x": {
1: None, 0: {
2: "default", "Category": "default",
3: "default", "User features": [0.0, 0.0],
}[index] },
) 1: {
instance.get_variable_features = Mock( # type: ignore "Category": None,
side_effect=lambda var, index: { },
0: [0.0, 0.0], 2: {
1: [0.0, 1.0], "Category": "default",
2: [1.0, 0.0], "User features": [1.0, 0.0],
3: [1.0, 1.0], },
}[index] 3: {
) "Category": "default",
"User features": [1.0, 1.0],
},
}
}
}
sample: TrainingSample = { sample: TrainingSample = {
"Solution": { "Solution": {
"x": { "x": {
@ -78,22 +83,27 @@ def test_xy_sample_with_lp_solution() -> None:
def test_xy_sample_without_lp_solution() -> None: def test_xy_sample_without_lp_solution() -> None:
comp = PrimalSolutionComponent() comp = PrimalSolutionComponent()
instance = cast(Instance, Mock(spec=Instance)) instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore instance.features = {
side_effect=lambda var_name, index: { "Variables": {
0: "default", "x": {
1: None, 0: {
2: "default", "Category": "default",
3: "default", "User features": [0.0, 0.0],
}[index] },
) 1: {
instance.get_variable_features = Mock( # type: ignore "Category": None,
side_effect=lambda var, index: { },
0: [0.0, 0.0], 2: {
1: [0.0, 1.0], "Category": "default",
2: [1.0, 0.0], "User features": [1.0, 0.0],
3: [1.0, 1.0], },
}[index] 3: {
) "Category": "default",
"User features": [1.0, 1.0],
},
}
}
}
sample: TrainingSample = { sample: TrainingSample = {
"Solution": { "Solution": {
"x": { "x": {
@ -143,22 +153,21 @@ def test_predict() -> None:
thr = Mock(spec=Threshold) thr = Mock(spec=Threshold)
thr.predict = Mock(return_value=[0.75, 0.75]) thr.predict = Mock(return_value=[0.75, 0.75])
instance = cast(Instance, Mock(spec=Instance)) instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
return_value="default",
)
instance.get_variable_features = Mock( # type: ignore
side_effect=lambda var, index: {
0: [0.0, 0.0],
1: [0.0, 2.0],
2: [2.0, 0.0],
}[index]
)
instance.features = { instance.features = {
"Variables": { "Variables": {
"x": { "x": {
0: None, 0: {
1: None, "Category": "default",
2: None, "User features": [0.0, 0.0],
},
1: {
"Category": "default",
"User features": [0.0, 2.0],
},
2: {
"Category": "default",
"User features": [2.0, 0.0],
},
} }
} }
} }

Loading…
Cancel
Save