mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Add instance-level features to instance.features
This commit is contained in:
@@ -6,7 +6,7 @@ import numbers
|
|||||||
import collections
|
import collections
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict
|
||||||
|
|
||||||
from miplearn.types import Features, ConstraintFeatures
|
from miplearn.types import Features, ConstraintFeatures, InstanceFeatures
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from miplearn import InternalSolver, Instance
|
from miplearn import InternalSolver, Instance
|
||||||
@@ -21,6 +21,7 @@ class FeaturesExtractor:
|
|||||||
|
|
||||||
def extract(self, instance: "Instance") -> Features:
|
def extract(self, instance: "Instance") -> Features:
|
||||||
return {
|
return {
|
||||||
|
"Instance": self._extract_instance(instance),
|
||||||
"Constraints": self._extract_constraints(instance),
|
"Constraints": self._extract_constraints(instance),
|
||||||
"Variables": self._extract_variables(instance),
|
"Variables": self._extract_variables(instance),
|
||||||
}
|
}
|
||||||
@@ -83,3 +84,7 @@ class FeaturesExtractor:
|
|||||||
"User features": user_features,
|
"User features": user_features,
|
||||||
}
|
}
|
||||||
return constraints
|
return constraints
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_instance(instance: "Instance") -> InstanceFeatures:
|
||||||
|
return {"User features": instance.get_instance_features()}
|
||||||
|
|||||||
@@ -73,6 +73,14 @@ LearningSolveStats = TypedDict(
|
|||||||
total=False,
|
total=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
InstanceFeatures = TypedDict(
|
||||||
|
"InstanceFeatures",
|
||||||
|
{
|
||||||
|
"User features": List[float],
|
||||||
|
},
|
||||||
|
total=False,
|
||||||
|
)
|
||||||
|
|
||||||
VariableFeatures = TypedDict(
|
VariableFeatures = TypedDict(
|
||||||
"VariableFeatures",
|
"VariableFeatures",
|
||||||
{
|
{
|
||||||
@@ -97,6 +105,7 @@ ConstraintFeatures = TypedDict(
|
|||||||
Features = TypedDict(
|
Features = TypedDict(
|
||||||
"Features",
|
"Features",
|
||||||
{
|
{
|
||||||
|
"Instance": InstanceFeatures,
|
||||||
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
|
"Variables": Dict[str, Dict[VarIndex, VariableFeatures]],
|
||||||
"Constraints": Dict[str, ConstraintFeatures],
|
"Constraints": Dict[str, ConstraintFeatures],
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -47,3 +47,6 @@ def test_knapsack() -> None:
|
|||||||
"Category": "eq_capacity",
|
"Category": "eq_capacity",
|
||||||
"User features": [0.0],
|
"User features": [0.0],
|
||||||
}
|
}
|
||||||
|
assert features["Instance"] == {
|
||||||
|
"User features": [67.0, 21.75],
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user