Request constraint features/categories in bulk

This commit is contained in:
2021-06-29 09:54:35 -05:00
parent 8118ab4110
commit a5092cc2b9
6 changed files with 51 additions and 38 deletions

View File

@@ -6,7 +6,6 @@ from unittest.mock import Mock
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import MinProbabilityThreshold
@@ -42,21 +41,21 @@ def training_instances() -> List[Instance]:
instances[0].samples[1].after_load.instance.to_list = Mock( # type: ignore
return_value=[5.0]
)
instances[0].get_constraint_category = Mock( # type: ignore
side_effect=lambda cid: {
instances[0].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": "type-a",
"c2": "type-a",
"c3": "type-b",
"c4": "type-b",
}[cid]
}
)
instances[0].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: {
return_value={
"c1": [1.0, 2.0, 3.0],
"c2": [4.0, 5.0, 6.0],
"c3": [1.0, 2.0],
"c4": [3.0, 4.0],
}[cid]
}
)
instances[1].samples = [
Sample(
@@ -67,20 +66,20 @@ def training_instances() -> List[Instance]:
instances[1].samples[0].after_load.instance.to_list = Mock( # type: ignore
return_value=[8.0]
)
instances[1].get_constraint_category = Mock( # type: ignore
side_effect=lambda cid: {
instances[1].get_constraint_categories = Mock( # type: ignore
return_value={
"c1": None,
"c2": "type-a",
"c3": "type-b",
"c4": "type-b",
}[cid]
}
)
instances[1].get_constraint_features = Mock( # type: ignore
side_effect=lambda cid: {
return_value={
"c2": [7.0, 8.0, 9.0],
"c3": [5.0, 6.0],
"c4": [7.0, 8.0],
}[cid]
}
)
return instances