Convert InstanceFeatures into dataclass

master
Alinson S. Xavier 5 years ago
parent d79eec5da6
commit 94084e0669
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -64,7 +64,7 @@ class StaticLazyConstraintsComponent(Component):
assert features.instance is not None assert features.instance is not None
assert features.constraints is not None assert features.constraints is not None
if not features.instance["Lazy constraint count"] == 0: if not features.instance.lazy_constraint_count == 0:
logger.info("Instance does not have static lazy constraints. Skipping.") logger.info("Instance does not have static lazy constraints. Skipping.")
logger.info("Predicting required lazy constraints...") logger.info("Predicting required lazy constraints...")
self.enforced_cids = set(self.sample_predict(features, training_data)) self.enforced_cids = set(self.sample_predict(features, training_data))

@ -80,7 +80,7 @@ class ObjectiveValueComponent(Component):
assert features.instance is not None assert features.instance is not None
x: Dict[Hashable, List[List[float]]] = {} x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {} y: Dict[Hashable, List[List[float]]] = {}
f = list(features.instance["User features"]) f = list(features.instance.user_features)
if "LP value" in sample and sample["LP value"] is not None: if "LP value" in sample and sample["LP value"] is not None:
f += [sample["LP value"]] f += [sample["LP value"]]
for c in ["Upper bound", "Lower bound"]: for c in ["Upper bound", "Lower bound"]:

@ -120,7 +120,7 @@ class FeaturesExtractor:
for (cid, cdict) in features.constraints.items(): for (cid, cdict) in features.constraints.items():
if cdict["Lazy"]: if cdict["Lazy"]:
lazy_count += 1 lazy_count += 1
return { return InstanceFeatures(
"User features": user_features, user_features=user_features,
"Lazy constraint count": lazy_count, lazy_constraint_count=lazy_count,
} )

@ -78,14 +78,11 @@ LearningSolveStats = TypedDict(
total=False, total=False,
) )
InstanceFeatures = TypedDict(
"InstanceFeatures", @dataclass
{ class InstanceFeatures:
"User features": List[float], user_features: List[float]
"Lazy constraint count": int, lazy_constraint_count: int = 0
},
total=False,
)
@dataclass @dataclass

@ -12,7 +12,12 @@ from miplearn import LearningSolver, InternalSolver, Instance
from miplearn.classifiers import Classifier from miplearn.classifiers import Classifier
from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold
from miplearn.components.lazy_static import StaticLazyConstraintsComponent from miplearn.components.lazy_static import StaticLazyConstraintsComponent
from miplearn.types import TrainingSample, Features, LearningSolveStats from miplearn.types import (
TrainingSample,
Features,
LearningSolveStats,
InstanceFeatures,
)
@pytest.fixture @pytest.fixture
@ -25,9 +30,10 @@ def sample() -> TrainingSample:
@pytest.fixture @pytest.fixture
def features() -> Features: def features() -> Features:
return Features( return Features(
instance={ instance=InstanceFeatures(
"Lazy constraint count": 4, user_features=[0],
}, lazy_constraint_count=4,
),
constraints={ constraints={
"c1": { "c1": {
"Category": "type-a", "Category": "type-a",

@ -9,7 +9,7 @@ from numpy.testing import assert_array_equal
from miplearn import GurobiPyomoSolver, LearningSolver, Regressor from miplearn import GurobiPyomoSolver, LearningSolver, Regressor
from miplearn.components.objective import ObjectiveValueComponent from miplearn.components.objective import ObjectiveValueComponent
from miplearn.types import TrainingSample, Features from miplearn.types import TrainingSample, Features, InstanceFeatures
from tests.fixtures.knapsack import get_knapsack_instance from tests.fixtures.knapsack import get_knapsack_instance
import numpy as np import numpy as np
@ -18,9 +18,9 @@ import numpy as np
@pytest.fixture @pytest.fixture
def features() -> Features: def features() -> Features:
return Features( return Features(
instance={ instance=InstanceFeatures(
"User features": [1.0, 2.0], user_features=[1.0, 2.0],
} )
) )

@ -4,7 +4,7 @@
from miplearn import GurobiSolver from miplearn import GurobiSolver
from miplearn.features import FeaturesExtractor from miplearn.features import FeaturesExtractor
from miplearn.types import VariableFeatures from miplearn.types import VariableFeatures, InstanceFeatures
from tests.fixtures.knapsack import get_knapsack_instance from tests.fixtures.knapsack import get_knapsack_instance
@ -50,7 +50,7 @@ def test_knapsack() -> None:
"User features": [0.0], "User features": [0.0],
} }
} }
assert instance.features.instance == { assert instance.features.instance == InstanceFeatures(
"User features": [67.0, 21.75], user_features=[67.0, 21.75],
"Lazy constraint count": 0, lazy_constraint_count=0,
} )

Loading…
Cancel
Save