mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Add instance-level features to instance.features
This commit is contained in:
@@ -6,7 +6,7 @@ import numbers
|
||||
import collections
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from miplearn.types import Features, ConstraintFeatures
|
||||
from miplearn.types import Features, ConstraintFeatures, InstanceFeatures
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from miplearn import InternalSolver, Instance
|
||||
@@ -21,6 +21,7 @@ class FeaturesExtractor:
|
||||
|
||||
def extract(self, instance: "Instance") -> Features:
|
||||
return {
|
||||
"Instance": self._extract_instance(instance),
|
||||
"Constraints": self._extract_constraints(instance),
|
||||
"Variables": self._extract_variables(instance),
|
||||
}
|
||||
@@ -83,3 +84,7 @@ class FeaturesExtractor:
|
||||
"User features": user_features,
|
||||
}
|
||||
return constraints
|
||||
|
||||
@staticmethod
|
||||
def _extract_instance(instance: "Instance") -> InstanceFeatures:
|
||||
return {"User features": instance.get_instance_features()}
|
||||
|
||||
@@ -73,6 +73,14 @@ LearningSolveStats = TypedDict(
|
||||
total=False,
|
||||
)
|
||||
|
||||
InstanceFeatures = TypedDict(
|
||||
"InstanceFeatures",
|
||||
{
|
||||
"User features": List[float],
|
||||
},
|
||||
total=False,
|
||||
)
|
||||
|
||||
VariableFeatures = TypedDict(
|
||||
"VariableFeatures",
|
||||
{
|
||||
@@ -97,6 +105,7 @@ ConstraintFeatures = TypedDict(
|
||||
Features = TypedDict(
|
||||
"Features",
|
||||
{
|
||||
"Instance": InstanceFeatures,
|
||||
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
|
||||
"Constraints": Dict[str, ConstraintFeatures],
|
||||
},
|
||||
|
||||
@@ -47,3 +47,6 @@ def test_knapsack() -> None:
|
||||
"Category": "eq_capacity",
|
||||
"User features": [0.0],
|
||||
}
|
||||
assert features["Instance"] == {
|
||||
"User features": [67.0, 21.75],
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user