mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Convert VariableFeatures into dataclass
This commit is contained in:
@@ -13,28 +13,28 @@ from miplearn.classifiers.threshold import Threshold
|
||||
from miplearn.components import classifier_evaluation_dict
|
||||
from miplearn.components.primal import PrimalSolutionComponent
|
||||
from miplearn.problems.tsp import TravelingSalesmanGenerator
|
||||
from miplearn.types import TrainingSample, Features
|
||||
from miplearn.types import TrainingSample, Features, VariableFeatures
|
||||
|
||||
|
||||
def test_xy() -> None:
|
||||
features = Features(
|
||||
variables={
|
||||
"x": {
|
||||
0: {
|
||||
"Category": "default",
|
||||
"User features": [0.0, 0.0],
|
||||
},
|
||||
1: {
|
||||
"Category": None,
|
||||
},
|
||||
2: {
|
||||
"Category": "default",
|
||||
"User features": [1.0, 0.0],
|
||||
},
|
||||
3: {
|
||||
"Category": "default",
|
||||
"User features": [1.0, 1.0],
|
||||
},
|
||||
0: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[0.0, 0.0],
|
||||
),
|
||||
1: VariableFeatures(
|
||||
category=None,
|
||||
),
|
||||
2: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[1.0, 0.0],
|
||||
),
|
||||
3: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[1.0, 1.0],
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -81,21 +81,21 @@ def test_xy_without_lp_solution() -> None:
|
||||
features = Features(
|
||||
variables={
|
||||
"x": {
|
||||
0: {
|
||||
"Category": "default",
|
||||
"User features": [0.0, 0.0],
|
||||
},
|
||||
1: {
|
||||
"Category": None,
|
||||
},
|
||||
2: {
|
||||
"Category": "default",
|
||||
"User features": [1.0, 0.0],
|
||||
},
|
||||
3: {
|
||||
"Category": "default",
|
||||
"User features": [1.0, 1.0],
|
||||
},
|
||||
0: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[0.0, 0.0],
|
||||
),
|
||||
1: VariableFeatures(
|
||||
category=None,
|
||||
),
|
||||
2: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[1.0, 0.0],
|
||||
),
|
||||
3: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[1.0, 1.0],
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -146,18 +146,18 @@ def test_predict() -> None:
|
||||
features = Features(
|
||||
variables={
|
||||
"x": {
|
||||
0: {
|
||||
"Category": "default",
|
||||
"User features": [0.0, 0.0],
|
||||
},
|
||||
1: {
|
||||
"Category": "default",
|
||||
"User features": [0.0, 2.0],
|
||||
},
|
||||
2: {
|
||||
"Category": "default",
|
||||
"User features": [2.0, 0.0],
|
||||
},
|
||||
0: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[0.0, 0.0],
|
||||
),
|
||||
1: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[0.0, 2.0],
|
||||
),
|
||||
2: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[2.0, 0.0],
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -246,11 +246,11 @@ def test_evaluate() -> None:
|
||||
features = Features(
|
||||
variables={
|
||||
"x": {
|
||||
0: {},
|
||||
1: {},
|
||||
2: {},
|
||||
3: {},
|
||||
4: {},
|
||||
0: VariableFeatures(),
|
||||
1: VariableFeatures(),
|
||||
2: VariableFeatures(),
|
||||
3: VariableFeatures(),
|
||||
4: VariableFeatures(),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from miplearn import GurobiSolver
|
||||
from miplearn.features import FeaturesExtractor
|
||||
from miplearn.types import VariableFeatures
|
||||
from tests.fixtures.knapsack import get_knapsack_instance
|
||||
|
||||
|
||||
@@ -16,22 +17,22 @@ def test_knapsack() -> None:
|
||||
FeaturesExtractor(solver).extract(instance)
|
||||
assert instance.features.variables == {
|
||||
"x": {
|
||||
0: {
|
||||
"Category": "default",
|
||||
"User features": [23.0, 505.0],
|
||||
},
|
||||
1: {
|
||||
"Category": "default",
|
||||
"User features": [26.0, 352.0],
|
||||
},
|
||||
2: {
|
||||
"Category": "default",
|
||||
"User features": [20.0, 458.0],
|
||||
},
|
||||
3: {
|
||||
"Category": "default",
|
||||
"User features": [18.0, 220.0],
|
||||
},
|
||||
0: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[23.0, 505.0],
|
||||
),
|
||||
1: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[26.0, 352.0],
|
||||
),
|
||||
2: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[20.0, 458.0],
|
||||
),
|
||||
3: VariableFeatures(
|
||||
category="default",
|
||||
user_features=[18.0, 220.0],
|
||||
),
|
||||
}
|
||||
}
|
||||
assert instance.features.constraints == {
|
||||
|
||||
Reference in New Issue
Block a user