From 94084e0669516c197ccd041cb55de6c14ff25e05 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Mon, 5 Apr 2021 20:02:24 -0500 Subject: [PATCH] Convert InstanceFeatures into dataclass --- miplearn/components/lazy_static.py | 2 +- miplearn/components/objective.py | 2 +- miplearn/features.py | 8 ++++---- miplearn/types.py | 13 +++++-------- tests/components/test_lazy_static.py | 14 ++++++++++---- tests/components/test_objective.py | 8 ++++---- tests/test_features.py | 10 +++++----- 7 files changed, 30 insertions(+), 27 deletions(-) diff --git a/miplearn/components/lazy_static.py b/miplearn/components/lazy_static.py index 37732b8..82ce0ff 100644 --- a/miplearn/components/lazy_static.py +++ b/miplearn/components/lazy_static.py @@ -64,7 +64,7 @@ class StaticLazyConstraintsComponent(Component): assert features.instance is not None assert features.constraints is not None - if not features.instance["Lazy constraint count"] == 0: + if not features.instance.lazy_constraint_count == 0: logger.info("Instance does not have static lazy constraints. Skipping.") logger.info("Predicting required lazy constraints...") self.enforced_cids = set(self.sample_predict(features, training_data)) diff --git a/miplearn/components/objective.py b/miplearn/components/objective.py index fb05e90..f77da1e 100644 --- a/miplearn/components/objective.py +++ b/miplearn/components/objective.py @@ -80,7 +80,7 @@ class ObjectiveValueComponent(Component): assert features.instance is not None x: Dict[Hashable, List[List[float]]] = {} y: Dict[Hashable, List[List[float]]] = {} - f = list(features.instance["User features"]) + f = list(features.instance.user_features) if "LP value" in sample and sample["LP value"] is not None: f += [sample["LP value"]] for c in ["Upper bound", "Lower bound"]: diff --git a/miplearn/features.py b/miplearn/features.py index 5b06bce..3cb2006 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -120,7 +120,7 @@ class FeaturesExtractor: for (cid, cdict) in features.constraints.items(): if cdict["Lazy"]: lazy_count += 1 - return { - "User features": user_features, - "Lazy constraint count": lazy_count, - } + return InstanceFeatures( + user_features=user_features, + lazy_constraint_count=lazy_count, + ) diff --git a/miplearn/types.py b/miplearn/types.py index ab6a548..0809ffe 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -78,14 +78,11 @@ LearningSolveStats = TypedDict( total=False, ) -InstanceFeatures = TypedDict( - "InstanceFeatures", - { - "User features": List[float], - "Lazy constraint count": int, - }, - total=False, -) + +@dataclass +class InstanceFeatures: + user_features: List[float] + lazy_constraint_count: int = 0 @dataclass diff --git a/tests/components/test_lazy_static.py b/tests/components/test_lazy_static.py index 3a633cb..354b68b 100644 --- a/tests/components/test_lazy_static.py +++ b/tests/components/test_lazy_static.py @@ -12,7 +12,12 @@ from miplearn import LearningSolver, InternalSolver, Instance from miplearn.classifiers import Classifier from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold from miplearn.components.lazy_static import StaticLazyConstraintsComponent -from miplearn.types import TrainingSample, Features, LearningSolveStats +from miplearn.types import ( + TrainingSample, + Features, + LearningSolveStats, + InstanceFeatures, +) @pytest.fixture @@ -25,9 +30,10 @@ def sample() -> TrainingSample: @pytest.fixture def features() -> Features: return Features( - instance={ - "Lazy constraint count": 4, - }, + instance=InstanceFeatures( + user_features=[0], + lazy_constraint_count=4, + ), constraints={ "c1": { "Category": "type-a", diff --git a/tests/components/test_objective.py b/tests/components/test_objective.py index 1c4c3bf..d1beeb0 100644 --- a/tests/components/test_objective.py +++ b/tests/components/test_objective.py @@ -9,7 +9,7 @@ from numpy.testing import assert_array_equal from miplearn import GurobiPyomoSolver, LearningSolver, Regressor from miplearn.components.objective import ObjectiveValueComponent -from miplearn.types import TrainingSample, Features +from miplearn.types import TrainingSample, Features, InstanceFeatures from tests.fixtures.knapsack import get_knapsack_instance import numpy as np @@ -18,9 +18,9 @@ import numpy as np @pytest.fixture def features() -> Features: return Features( - instance={ - "User features": [1.0, 2.0], - } + instance=InstanceFeatures( + user_features=[1.0, 2.0], + ) ) diff --git a/tests/test_features.py b/tests/test_features.py index ba1546f..f035860 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -4,7 +4,7 @@ from miplearn import GurobiSolver from miplearn.features import FeaturesExtractor -from miplearn.types import VariableFeatures +from miplearn.types import VariableFeatures, InstanceFeatures from tests.fixtures.knapsack import get_knapsack_instance @@ -50,7 +50,7 @@ def test_knapsack() -> None: "User features": [0.0], } } - assert instance.features.instance == { - "User features": [67.0, 21.75], - "Lazy constraint count": 0, - } + assert instance.features.instance == InstanceFeatures( + user_features=[67.0, 21.75], + lazy_constraint_count=0, + )