Refactor primal

This commit is contained in:
2021-03-30 21:44:13 -05:00
parent 9cf28f3cdc
commit ec69464794
2 changed files with 46 additions and 98 deletions

View File

@@ -15,7 +15,6 @@ from miplearn.types import TrainingSample
def test_xy_sample_with_lp_solution() -> None:
comp = PrimalSolutionComponent()
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
side_effect=lambda var_name, index: {
@@ -131,8 +130,6 @@ def test_xy_sample_without_lp_solution() -> None:
def test_predict() -> None:
comp = PrimalSolutionComponent()
clf = Mock(spec=Classifier)
clf.predict_proba = Mock(
return_value=np.array(
@@ -143,12 +140,8 @@ def test_predict() -> None:
]
)
)
comp.classifiers = {"default": clf}
thr = Mock(spec=Threshold)
thr.predict = Mock(return_value=[0.75, 0.75])
comp.thresholds = {"default": thr}
instance = cast(Instance, Mock(spec=Instance))
instance.get_variable_category = Mock( # type: ignore
return_value="default",
@@ -160,6 +153,15 @@ def test_predict() -> None:
2: [2.0, 0.0],
}[index]
)
instance.model_features = {
"Variables": {
"x": {
0: None,
1: None,
2: None,
}
}
}
instance.training_data = [
{
"LP solution": {
@@ -171,16 +173,23 @@ def test_predict() -> None:
}
}
]
x = comp.x([instance])
x = {
"default": np.array(
[
[0.0, 0.0, 0.1],
[0.0, 2.0, 0.5],
[2.0, 0.0, 0.9],
]
)
}
comp = PrimalSolutionComponent()
comp.classifiers = {"default": clf}
comp.thresholds = {"default": thr}
solution_actual = comp.predict(instance)
# Should ask for probabilities and thresholds
clf.predict_proba.assert_called_once()
thr.predict.assert_called_once()
assert_array_equal(x["default"], clf.predict_proba.call_args[0][0])
assert_array_equal(x["default"], thr.predict.call_args[0][0])
assert solution_actual == {
"x": {
0: 0.0,