Convert VariableFeatures into dataclass

This commit is contained in:
2021-04-04 22:56:26 -05:00
parent 59f4f75a53
commit d79eec5da6
6 changed files with 97 additions and 87 deletions

View File

@@ -13,28 +13,28 @@ from miplearn.classifiers.threshold import Threshold
from miplearn.components import classifier_evaluation_dict
from miplearn.components.primal import PrimalSolutionComponent
from miplearn.problems.tsp import TravelingSalesmanGenerator
from miplearn.types import TrainingSample, Features
from miplearn.types import TrainingSample, Features, VariableFeatures
def test_xy() -> None:
features = Features(
variables={
"x": {
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": None,
},
2: {
"Category": "default",
"User features": [1.0, 0.0],
},
3: {
"Category": "default",
"User features": [1.0, 1.0],
},
0: VariableFeatures(
category="default",
user_features=[0.0, 0.0],
),
1: VariableFeatures(
category=None,
),
2: VariableFeatures(
category="default",
user_features=[1.0, 0.0],
),
3: VariableFeatures(
category="default",
user_features=[1.0, 1.0],
),
}
}
)
@@ -81,21 +81,21 @@ def test_xy_without_lp_solution() -> None:
features = Features(
variables={
"x": {
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": None,
},
2: {
"Category": "default",
"User features": [1.0, 0.0],
},
3: {
"Category": "default",
"User features": [1.0, 1.0],
},
0: VariableFeatures(
category="default",
user_features=[0.0, 0.0],
),
1: VariableFeatures(
category=None,
),
2: VariableFeatures(
category="default",
user_features=[1.0, 0.0],
),
3: VariableFeatures(
category="default",
user_features=[1.0, 1.0],
),
}
}
)
@@ -146,18 +146,18 @@ def test_predict() -> None:
features = Features(
variables={
"x": {
0: {
"Category": "default",
"User features": [0.0, 0.0],
},
1: {
"Category": "default",
"User features": [0.0, 2.0],
},
2: {
"Category": "default",
"User features": [2.0, 0.0],
},
0: VariableFeatures(
category="default",
user_features=[0.0, 0.0],
),
1: VariableFeatures(
category="default",
user_features=[0.0, 2.0],
),
2: VariableFeatures(
category="default",
user_features=[2.0, 0.0],
),
}
}
)
@@ -246,11 +246,11 @@ def test_evaluate() -> None:
features = Features(
variables={
"x": {
0: {},
1: {},
2: {},
3: {},
4: {},
0: VariableFeatures(),
1: VariableFeatures(),
2: VariableFeatures(),
3: VariableFeatures(),
4: VariableFeatures(),
}
}
)

View File

@@ -4,6 +4,7 @@
from miplearn import GurobiSolver
from miplearn.features import FeaturesExtractor
from miplearn.types import VariableFeatures
from tests.fixtures.knapsack import get_knapsack_instance
@@ -16,22 +17,22 @@ def test_knapsack() -> None:
FeaturesExtractor(solver).extract(instance)
assert instance.features.variables == {
"x": {
0: {
"Category": "default",
"User features": [23.0, 505.0],
},
1: {
"Category": "default",
"User features": [26.0, 352.0],
},
2: {
"Category": "default",
"User features": [20.0, 458.0],
},
3: {
"Category": "default",
"User features": [18.0, 220.0],
},
0: VariableFeatures(
category="default",
user_features=[23.0, 505.0],
),
1: VariableFeatures(
category="default",
user_features=[26.0, 352.0],
),
2: VariableFeatures(
category="default",
user_features=[20.0, 458.0],
),
3: VariableFeatures(
category="default",
user_features=[18.0, 220.0],
),
}
}
assert instance.features.constraints == {