Primal: Use instance.features

This commit is contained in:
2021-03-31 08:22:43 -05:00
parent 12fca1f22b
commit 4f46866921
3 changed files with 86 additions and 82 deletions

View File

@@ -16,22 +16,27 @@ from miplearn.types import TrainingSample
def test_xy_sample_with_lp_solution() -> None:
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
side_effect=lambda var_name, index: {
0: "default",
1: None,
2: "default",
3: "default",
}[index]
)
instance.get_variable_features = Mock( # type: ignore
side_effect=lambda var, index: {
0: [0.0, 0.0],
1: [0.0, 1.0],
2: [1.0, 0.0],
3: [1.0, 1.0],
}[index]
)
instance.features = {
"Variables": {
"x": {
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": None,
},
2: {
"Category": "default",
"User features": [1.0, 0.0],
},
3: {
"Category": "default",
"User features": [1.0, 1.0],
},
}
}
}
sample: TrainingSample = {
"Solution": {
"x": {
@@ -78,22 +83,27 @@ def test_xy_sample_with_lp_solution() -> None:
def test_xy_sample_without_lp_solution() -> None:
comp = PrimalSolutionComponent()
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
side_effect=lambda var_name, index: {
0: "default",
1: None,
2: "default",
3: "default",
}[index]
)
instance.get_variable_features = Mock( # type: ignore
side_effect=lambda var, index: {
0: [0.0, 0.0],
1: [0.0, 1.0],
2: [1.0, 0.0],
3: [1.0, 1.0],
}[index]
)
instance.features = {
"Variables": {
"x": {
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": None,
},
2: {
"Category": "default",
"User features": [1.0, 0.0],
},
3: {
"Category": "default",
"User features": [1.0, 1.0],
},
}
}
}
sample: TrainingSample = {
"Solution": {
"x": {
@@ -143,22 +153,21 @@ def test_predict() -> None:
thr = Mock(spec=Threshold)
thr.predict = Mock(return_value=[0.75, 0.75])
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
return_value="default",
)
instance.get_variable_features = Mock( # type: ignore
side_effect=lambda var, index: {
0: [0.0, 0.0],
1: [0.0, 2.0],
2: [2.0, 0.0],
}[index]
)
instance.features = {
"Variables": {
"x": {
0: None,
1: None,
2: None,
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": "default",
"User features": [0.0, 2.0],
},
2: {
"Category": "default",
"User features": [2.0, 0.0],
},
}
}
}