From d79eec5da652045018813a74827222157fd64d58 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Sun, 4 Apr 2021 22:56:26 -0500 Subject: [PATCH] Convert VariableFeatures into dataclass --- miplearn/components/primal.py | 8 +-- miplearn/features.py | 31 +++++++---- miplearn/solvers/internal.py | 2 +- miplearn/types.py | 14 +++-- tests/components/test_primal.py | 96 ++++++++++++++++----------------- tests/test_features.py | 33 ++++++------ 6 files changed, 97 insertions(+), 87 deletions(-) diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 13a49b8..eaa6a75 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -136,7 +136,7 @@ class PrimalSolutionComponent(Component): category_offset: Dict[Hashable, int] = {cat: 0 for cat in x.keys()} for (var_name, var_dict) in features.variables.items(): for (idx, var_features) in var_dict.items(): - category = var_features["Category"] + category = var_features.category offset = category_offset[category] category_offset[category] += 1 if y_pred[category][offset, 0]: @@ -159,15 +159,15 @@ class PrimalSolutionComponent(Component): solution = sample["Solution"] for (var_name, var_dict) in features.variables.items(): for (idx, var_features) in var_dict.items(): - category = var_features["Category"] + category = var_features.category if category is None: continue if category not in x.keys(): x[category] = [] y[category] = [] f: List[float] = [] - assert var_features["User features"] is not None - f += var_features["User features"] + assert var_features.user_features is not None + f += var_features.user_features if "LP solution" in sample and sample["LP solution"] is not None: lp_value = sample["LP solution"][var_name][idx] if lp_value is not None: diff --git a/miplearn/features.py b/miplearn/features.py index c401f25..5b06bce 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -4,9 +4,15 @@ import numbers import collections -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Hashable -from miplearn.types import Features, ConstraintFeatures, InstanceFeatures +from miplearn.types import ( + Features, + ConstraintFeatures, + InstanceFeatures, + VariableFeatures, + VarIndex, +) if TYPE_CHECKING: from miplearn import InternalSolver, Instance @@ -24,9 +30,14 @@ class FeaturesExtractor: instance.features.constraints = self._extract_constraints(instance) instance.features.instance = self._extract_instance(instance, instance.features) - def _extract_variables(self, instance: "Instance") -> Dict: - variables = self.solver.get_empty_solution() - for (var_name, var_dict) in variables.items(): + def _extract_variables( + self, + instance: "Instance", + ) -> Dict[str, Dict[VarIndex, VariableFeatures]]: + result: Dict[str, Dict[VarIndex, VariableFeatures]] = {} + empty_solution = self.solver.get_empty_solution() + for (var_name, var_dict) in empty_solution.items(): + result[var_name] = {} for idx in var_dict.keys(): user_features = None category = instance.get_variable_category(var_name, idx) @@ -47,11 +58,11 @@ class FeaturesExtractor: f"Found {type(v).__name__} instead " f"for var={var_name}[{idx}]." ) - var_dict[idx] = { - "Category": category, - "User features": user_features, - } - return variables + result[var_name][idx] = VariableFeatures( + category=category, + user_features=user_features, + ) + return result def _extract_constraints( self, diff --git a/miplearn/solvers/internal.py b/miplearn/solvers/internal.py index 48a352a..20c798f 100644 --- a/miplearn/solvers/internal.py +++ b/miplearn/solvers/internal.py @@ -274,7 +274,7 @@ class InternalSolver(ABC): pass @abstractmethod - def get_empty_solution(self) -> Dict: + def get_empty_solution(self) -> Dict[str, Dict[VarIndex, Optional[float]]]: """ Returns a dictionary with the same shape as the one produced by `get_solution`, but with all values set to None. This method is diff --git a/miplearn/types.py b/miplearn/types.py index 9c5c9f4..ab6a548 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -87,14 +87,12 @@ InstanceFeatures = TypedDict( total=False, ) -VariableFeatures = TypedDict( - "VariableFeatures", - { - "Category": Optional[Hashable], - "User features": Optional[List[float]], - }, - total=False, -) + +@dataclass +class VariableFeatures: + category: Optional[Hashable] = None + user_features: Optional[List[float]] = None + ConstraintFeatures = TypedDict( "ConstraintFeatures", diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index 1613442..2c2ee8f 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -13,28 +13,28 @@ from miplearn.classifiers.threshold import Threshold from miplearn.components import classifier_evaluation_dict from miplearn.components.primal import PrimalSolutionComponent from miplearn.problems.tsp import TravelingSalesmanGenerator -from miplearn.types import TrainingSample, Features +from miplearn.types import TrainingSample, Features, VariableFeatures def test_xy() -> None: features = Features( variables={ "x": { - 0: { - "Category": "default", - "User features": [0.0, 0.0], - }, - 1: { - "Category": None, - }, - 2: { - "Category": "default", - "User features": [1.0, 0.0], - }, - 3: { - "Category": "default", - "User features": [1.0, 1.0], - }, + 0: VariableFeatures( + category="default", + user_features=[0.0, 0.0], + ), + 1: VariableFeatures( + category=None, + ), + 2: VariableFeatures( + category="default", + user_features=[1.0, 0.0], + ), + 3: VariableFeatures( + category="default", + user_features=[1.0, 1.0], + ), } } ) @@ -81,21 +81,21 @@ def test_xy_without_lp_solution() -> None: features = Features( variables={ "x": { - 0: { - "Category": "default", - "User features": [0.0, 0.0], - }, - 1: { - "Category": None, - }, - 2: { - "Category": "default", - "User features": [1.0, 0.0], - }, - 3: { - "Category": "default", - "User features": [1.0, 1.0], - }, + 0: VariableFeatures( + category="default", + user_features=[0.0, 0.0], + ), + 1: VariableFeatures( + category=None, + ), + 2: VariableFeatures( + category="default", + user_features=[1.0, 0.0], + ), + 3: VariableFeatures( + category="default", + user_features=[1.0, 1.0], + ), } } ) @@ -146,18 +146,18 @@ def test_predict() -> None: features = Features( variables={ "x": { - 0: { - "Category": "default", - "User features": [0.0, 0.0], - }, - 1: { - "Category": "default", - "User features": [0.0, 2.0], - }, - 2: { - "Category": "default", - "User features": [2.0, 0.0], - }, + 0: VariableFeatures( + category="default", + user_features=[0.0, 0.0], + ), + 1: VariableFeatures( + category="default", + user_features=[0.0, 2.0], + ), + 2: VariableFeatures( + category="default", + user_features=[2.0, 0.0], + ), } } ) @@ -246,11 +246,11 @@ def test_evaluate() -> None: features = Features( variables={ "x": { - 0: {}, - 1: {}, - 2: {}, - 3: {}, - 4: {}, + 0: VariableFeatures(), + 1: VariableFeatures(), + 2: VariableFeatures(), + 3: VariableFeatures(), + 4: VariableFeatures(), } } ) diff --git a/tests/test_features.py b/tests/test_features.py index 5103096..ba1546f 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -4,6 +4,7 @@ from miplearn import GurobiSolver from miplearn.features import FeaturesExtractor +from miplearn.types import VariableFeatures from tests.fixtures.knapsack import get_knapsack_instance @@ -16,22 +17,22 @@ def test_knapsack() -> None: FeaturesExtractor(solver).extract(instance) assert instance.features.variables == { "x": { - 0: { - "Category": "default", - "User features": [23.0, 505.0], - }, - 1: { - "Category": "default", - "User features": [26.0, 352.0], - }, - 2: { - "Category": "default", - "User features": [20.0, 458.0], - }, - 3: { - "Category": "default", - "User features": [18.0, 220.0], - }, + 0: VariableFeatures( + category="default", + user_features=[23.0, 505.0], + ), + 1: VariableFeatures( + category="default", + user_features=[26.0, 352.0], + ), + 2: VariableFeatures( + category="default", + user_features=[20.0, 458.0], + ), + 3: VariableFeatures( + category="default", + user_features=[18.0, 220.0], + ), } } assert instance.features.constraints == {