Convert InstanceFeatures into dataclass

This commit is contained in:
2021-04-05 20:02:24 -05:00
parent d79eec5da6
commit 94084e0669
7 changed files with 30 additions and 27 deletions

View File

@@ -64,7 +64,7 @@ class StaticLazyConstraintsComponent(Component):
 assert features.instance is not None
 assert features.constraints is not None
-if not features.instance["Lazy constraint count"] == 0:
+if not features.instance.lazy_constraint_count == 0:
     logger.info("Instance does not have static lazy constraints. Skipping.")
 logger.info("Predicting required lazy constraints...")
 self.enforced_cids = set(self.sample_predict(features, training_data))

View File

@@ -80,7 +80,7 @@ class ObjectiveValueComponent(Component):
 assert features.instance is not None
 x: Dict[Hashable, List[List[float]]] = {}
 y: Dict[Hashable, List[List[float]]] = {}
-f = list(features.instance["User features"])
+f = list(features.instance.user_features)
 if "LP value" in sample and sample["LP value"] is not None:
     f += [sample["LP value"]]
 for c in ["Upper bound", "Lower bound"]:

View File

@@ -120,7 +120,7 @@ class FeaturesExtractor:
 for (cid, cdict) in features.constraints.items():
     if cdict["Lazy"]:
         lazy_count += 1
-return {
-    "User features": user_features,
-    "Lazy constraint count": lazy_count,
-}
+return InstanceFeatures(
+    user_features=user_features,
+    lazy_constraint_count=lazy_count,
+)

View File

@@ -78,14 +78,11 @@ LearningSolveStats = TypedDict(
 total=False,
 )
-InstanceFeatures = TypedDict(
-    "InstanceFeatures",
-    {
-        "User features": List[float],
-        "Lazy constraint count": int,
-    },
-    total=False,
-)
+@dataclass
+class InstanceFeatures:
+    user_features: List[float]
+    lazy_constraint_count: int = 0
 @dataclass

View File

@@ -12,7 +12,12 @@ from miplearn import LearningSolver, InternalSolver, Instance
 from miplearn.classifiers import Classifier
 from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold
 from miplearn.components.lazy_static import StaticLazyConstraintsComponent
-from miplearn.types import TrainingSample, Features, LearningSolveStats
+from miplearn.types import (
+    TrainingSample,
+    Features,
+    LearningSolveStats,
+    InstanceFeatures,
+)
 @pytest.fixture
@@ -25,9 +30,10 @@ def sample() -> TrainingSample:
 @pytest.fixture
 def features() -> Features:
     return Features(
-        instance={
-            "Lazy constraint count": 4,
-        },
+        instance=InstanceFeatures(
+            user_features=[0],
+            lazy_constraint_count=4,
+        ),
         constraints={
             "c1": {
                 "Category": "type-a",

View File

@@ -9,7 +9,7 @@ from numpy.testing import assert_array_equal
 from miplearn import GurobiPyomoSolver, LearningSolver, Regressor
 from miplearn.components.objective import ObjectiveValueComponent
-from miplearn.types import TrainingSample, Features
+from miplearn.types import TrainingSample, Features, InstanceFeatures
 from tests.fixtures.knapsack import get_knapsack_instance
 import numpy as np
@@ -18,9 +18,9 @@ import numpy as np
 @pytest.fixture
 def features() -> Features:
     return Features(
-        instance={
-            "User features": [1.0, 2.0],
-        }
+        instance=InstanceFeatures(
+            user_features=[1.0, 2.0],
+        )
     )

View File

@@ -4,7 +4,7 @@
 from miplearn import GurobiSolver
 from miplearn.features import FeaturesExtractor
-from miplearn.types import VariableFeatures
+from miplearn.types import VariableFeatures, InstanceFeatures
 from tests.fixtures.knapsack import get_knapsack_instance
@@ -50,7 +50,7 @@ def test_knapsack() -> None:
         "User features": [0.0],
     }
 }
-assert instance.features.instance == {
-    "User features": [67.0, 21.75],
-    "Lazy constraint count": 0,
-}
+assert instance.features.instance == InstanceFeatures(
+    user_features=[67.0, 21.75],
+    lazy_constraint_count=0,
+)