Extract all features ahead of time

This commit is contained in:
2021-03-31 07:42:01 -05:00
parent b3c24814b0
commit 12fca1f22b
8 changed files with 106 additions and 39 deletions

View File

@@ -153,7 +153,7 @@ def test_predict() -> None:
2: [2.0, 0.0],
}[index]
)
instance.model_features = {
instance.features = {
"Variables": {
"x": {
0: None,

View File

@@ -27,7 +27,7 @@ def test_learning_solver():
solver.solve(instance)
assert hasattr(instance, "model_features")
assert hasattr(instance, "features")
data = instance.training_data[0]
assert data["Solution"]["x"][0] == 1.0

View File

@@ -3,37 +3,47 @@
# Released under the modified BSD license. See COPYING.md for more details.
from miplearn import GurobiSolver
from miplearn.features import ModelFeaturesExtractor
from miplearn.features import FeaturesExtractor
from tests.fixtures.knapsack import get_knapsack_instance
def test_knapsack() -> None:
for solver_factory in [GurobiSolver]:
# Initialize model, instance and internal solver
solver = solver_factory()
instance = get_knapsack_instance(solver)
model = instance.to_model()
solver.set_instance(instance, model)
# Extract all model features
extractor = ModelFeaturesExtractor(solver)
features = extractor.extract()
# Test constraint features
print(solver, features)
extractor = FeaturesExtractor(solver)
features = extractor.extract(instance)
assert features["Variables"] == {
"x": {
0: None,
1: None,
2: None,
3: None,
0: {
"Category": "default",
"User features": [23.0, 505.0],
},
1: {
"Category": "default",
"User features": [26.0, 352.0],
},
2: {
"Category": "default",
"User features": [20.0, 458.0],
},
3: {
"Category": "default",
"User features": [18.0, 220.0],
},
}
}
assert features["Constraints"]["eq_capacity"]["LHS"] == {
"x[0]": 23.0,
"x[1]": 26.0,
"x[2]": 20.0,
"x[3]": 18.0,
assert features["Constraints"]["eq_capacity"] == {
"LHS": {
"x[0]": 23.0,
"x[1]": 26.0,
"x[2]": 20.0,
"x[3]": 18.0,
},
"Sense": "<",
"RHS": 67.0,
"Category": "eq_capacity",
"User features": [0.0],
}
assert features["Constraints"]["eq_capacity"]["Sense"] == "<"
assert features["Constraints"]["eq_capacity"]["RHS"] == 67.0