Convert InstanceFeatures into dataclass

master
Alinson S. Xavier 5 years ago
parent d79eec5da6
commit 94084e0669
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -64,7 +64,7 @@ class StaticLazyConstraintsComponent(Component):
assert features.instance is not None
assert features.constraints is not None
if not features.instance["Lazy constraint count"] == 0:
if not features.instance.lazy_constraint_count == 0:
logger.info("Instance does not have static lazy constraints. Skipping.")
logger.info("Predicting required lazy constraints...")
self.enforced_cids = set(self.sample_predict(features, training_data))

@ -80,7 +80,7 @@ class ObjectiveValueComponent(Component):
assert features.instance is not None
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {}
f = list(features.instance["User features"])
f = list(features.instance.user_features)
if "LP value" in sample and sample["LP value"] is not None:
f += [sample["LP value"]]
for c in ["Upper bound", "Lower bound"]:

@ -120,7 +120,7 @@ class FeaturesExtractor:
for (cid, cdict) in features.constraints.items():
if cdict["Lazy"]:
lazy_count += 1
return {
"User features": user_features,
"Lazy constraint count": lazy_count,
}
return InstanceFeatures(
user_features=user_features,
lazy_constraint_count=lazy_count,
)

@ -78,14 +78,11 @@ LearningSolveStats = TypedDict(
total=False,
)
InstanceFeatures = TypedDict(
"InstanceFeatures",
{
"User features": List[float],
"Lazy constraint count": int,
},
total=False,
)
@dataclass
class InstanceFeatures:
user_features: List[float]
lazy_constraint_count: int = 0
@dataclass

@ -12,7 +12,12 @@ from miplearn import LearningSolver, InternalSolver, Instance
from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold
from miplearn.components.lazy_static import StaticLazyConstraintsComponent
from miplearn.types import TrainingSample, Features, LearningSolveStats
from miplearn.types import (
TrainingSample,
Features,
LearningSolveStats,
InstanceFeatures,
)
@pytest.fixture
@ -25,9 +30,10 @@ def sample() -> TrainingSample:
@pytest.fixture
def features() -> Features:
return Features(
instance={
"Lazy constraint count": 4,
},
instance=InstanceFeatures(
user_features=[0],
lazy_constraint_count=4,
),
constraints={
"c1": {
"Category": "type-a",

@ -9,7 +9,7 @@ from numpy.testing import assert_array_equal
from miplearn import GurobiPyomoSolver, LearningSolver, Regressor
from miplearn.components.objective import ObjectiveValueComponent
from miplearn.types import TrainingSample, Features
from miplearn.types import TrainingSample, Features, InstanceFeatures
from tests.fixtures.knapsack import get_knapsack_instance
import numpy as np
@ -18,9 +18,9 @@ import numpy as np
@pytest.fixture
def features() -> Features:
return Features(
instance={
"User features": [1.0, 2.0],
}
instance=InstanceFeatures(
user_features=[1.0, 2.0],
)
)

@ -4,7 +4,7 @@
from miplearn import GurobiSolver
from miplearn.features import FeaturesExtractor
from miplearn.types import VariableFeatures
from miplearn.types import VariableFeatures, InstanceFeatures
from tests.fixtures.knapsack import get_knapsack_instance
@ -50,7 +50,7 @@ def test_knapsack() -> None:
"User features": [0.0],
}
}
assert instance.features.instance == {
"User features": [67.0, 21.75],
"Lazy constraint count": 0,
}
assert instance.features.instance == InstanceFeatures(
user_features=[67.0, 21.75],
lazy_constraint_count=0,
)

Loading…
Cancel
Save