mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Convert InstanceFeatures into dataclass
This commit is contained in:
@@ -64,7 +64,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
assert features.instance is not None
|
||||
assert features.constraints is not None
|
||||
|
||||
if not features.instance["Lazy constraint count"] == 0:
|
||||
if not features.instance.lazy_constraint_count == 0:
|
||||
logger.info("Instance does not have static lazy constraints. Skipping.")
|
||||
logger.info("Predicting required lazy constraints...")
|
||||
self.enforced_cids = set(self.sample_predict(features, training_data))
|
||||
|
||||
@@ -80,7 +80,7 @@ class ObjectiveValueComponent(Component):
|
||||
assert features.instance is not None
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
f = list(features.instance["User features"])
|
||||
f = list(features.instance.user_features)
|
||||
if "LP value" in sample and sample["LP value"] is not None:
|
||||
f += [sample["LP value"]]
|
||||
for c in ["Upper bound", "Lower bound"]:
|
||||
|
||||
@@ -120,7 +120,7 @@ class FeaturesExtractor:
|
||||
for (cid, cdict) in features.constraints.items():
|
||||
if cdict["Lazy"]:
|
||||
lazy_count += 1
|
||||
return {
|
||||
"User features": user_features,
|
||||
"Lazy constraint count": lazy_count,
|
||||
}
|
||||
return InstanceFeatures(
|
||||
user_features=user_features,
|
||||
lazy_constraint_count=lazy_count,
|
||||
)
|
||||
|
||||
@@ -78,14 +78,11 @@ LearningSolveStats = TypedDict(
|
||||
total=False,
|
||||
)
|
||||
|
||||
InstanceFeatures = TypedDict(
|
||||
"InstanceFeatures",
|
||||
{
|
||||
"User features": List[float],
|
||||
"Lazy constraint count": int,
|
||||
},
|
||||
total=False,
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class InstanceFeatures:
|
||||
user_features: List[float]
|
||||
lazy_constraint_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -12,7 +12,12 @@ from miplearn import LearningSolver, InternalSolver, Instance
|
||||
from miplearn.classifiers import Classifier
|
||||
from miplearn.classifiers.threshold import Threshold, MinProbabilityThreshold
|
||||
from miplearn.components.lazy_static import StaticLazyConstraintsComponent
|
||||
from miplearn.types import TrainingSample, Features, LearningSolveStats
|
||||
from miplearn.types import (
|
||||
TrainingSample,
|
||||
Features,
|
||||
LearningSolveStats,
|
||||
InstanceFeatures,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -25,9 +30,10 @@ def sample() -> TrainingSample:
|
||||
@pytest.fixture
|
||||
def features() -> Features:
|
||||
return Features(
|
||||
instance={
|
||||
"Lazy constraint count": 4,
|
||||
},
|
||||
instance=InstanceFeatures(
|
||||
user_features=[0],
|
||||
lazy_constraint_count=4,
|
||||
),
|
||||
constraints={
|
||||
"c1": {
|
||||
"Category": "type-a",
|
||||
|
||||
@@ -9,7 +9,7 @@ from numpy.testing import assert_array_equal
|
||||
|
||||
from miplearn import GurobiPyomoSolver, LearningSolver, Regressor
|
||||
from miplearn.components.objective import ObjectiveValueComponent
|
||||
from miplearn.types import TrainingSample, Features
|
||||
from miplearn.types import TrainingSample, Features, InstanceFeatures
|
||||
from tests.fixtures.knapsack import get_knapsack_instance
|
||||
|
||||
import numpy as np
|
||||
@@ -18,9 +18,9 @@ import numpy as np
|
||||
@pytest.fixture
|
||||
def features() -> Features:
|
||||
return Features(
|
||||
instance={
|
||||
"User features": [1.0, 2.0],
|
||||
}
|
||||
instance=InstanceFeatures(
|
||||
user_features=[1.0, 2.0],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
from miplearn import GurobiSolver
|
||||
from miplearn.features import FeaturesExtractor
|
||||
from miplearn.types import VariableFeatures
|
||||
from miplearn.types import VariableFeatures, InstanceFeatures
|
||||
from tests.fixtures.knapsack import get_knapsack_instance
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_knapsack() -> None:
|
||||
"User features": [0.0],
|
||||
}
|
||||
}
|
||||
assert instance.features.instance == {
|
||||
"User features": [67.0, 21.75],
|
||||
"Lazy constraint count": 0,
|
||||
}
|
||||
assert instance.features.instance == InstanceFeatures(
|
||||
user_features=[67.0, 21.75],
|
||||
lazy_constraint_count=0,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user