diff --git a/miplearn/features.py b/miplearn/features.py index d4ac84f..35a2607 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -6,7 +6,7 @@ import numbers import collections from typing import TYPE_CHECKING, Dict -from miplearn.types import Features, ConstraintFeatures +from miplearn.types import Features, ConstraintFeatures, InstanceFeatures if TYPE_CHECKING: from miplearn import InternalSolver, Instance @@ -21,6 +21,7 @@ class FeaturesExtractor: def extract(self, instance: "Instance") -> Features: return { + "Instance": self._extract_instance(instance), "Constraints": self._extract_constraints(instance), "Variables": self._extract_variables(instance), } @@ -83,3 +84,7 @@ class FeaturesExtractor: "User features": user_features, } return constraints + + @staticmethod + def _extract_instance(instance: "Instance") -> InstanceFeatures: + return {"User features": instance.get_instance_features()} diff --git a/miplearn/types.py b/miplearn/types.py index de32cb8..a19d1c9 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -73,6 +73,14 @@ LearningSolveStats = TypedDict( total=False, ) +InstanceFeatures = TypedDict( + "InstanceFeatures", + { + "User features": List[float], + }, + total=False, +) + VariableFeatures = TypedDict( "VariableFeatures", { @@ -97,6 +105,7 @@ ConstraintFeatures = TypedDict( Features = TypedDict( "Features", { + "Instance": InstanceFeatures, "Variables": Dict[str, Dict[VarIndex, VariableFeatures]], "Constraints": Dict[str, ConstraintFeatures], }, diff --git a/tests/test_features.py b/tests/test_features.py index e5be4c2..6947b35 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -47,3 +47,6 @@ def test_knapsack() -> None: "Category": "eq_capacity", "User features": [0.0], } + assert features["Instance"] == { + "User features": [67.0, 21.75], + }