Convert ConstraintFeatures to dataclass

master
Alinson S. Xavier 5 years ago
parent 94084e0669
commit aeed338837
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -3,11 +3,9 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
import sys from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set
from typing import Dict, Tuple, Optional, List, Hashable, Any, TYPE_CHECKING, Set
import numpy as np import numpy as np
from tqdm.auto import tqdm
from miplearn import Classifier from miplearn import Classifier
from miplearn.classifiers.counting import CountingClassifier from miplearn.classifiers.counting import CountingClassifier
@ -71,7 +69,7 @@ class StaticLazyConstraintsComponent(Component):
logger.info("Moving lazy constraints to the pool...") logger.info("Moving lazy constraints to the pool...")
self.pool = {} self.pool = {}
for (cid, cdict) in features.constraints.items(): for (cid, cdict) in features.constraints.items():
if cdict["Lazy"] and cid not in self.enforced_cids: if cdict.lazy and cid not in self.enforced_cids:
self.pool[cid] = LazyConstraint( self.pool[cid] = LazyConstraint(
cid=cid, cid=cid,
obj=solver.internal_solver.extract_constraint(cid), obj=solver.internal_solver.extract_constraint(cid),
@ -152,10 +150,10 @@ class StaticLazyConstraintsComponent(Component):
x, y = self.sample_xy(features, sample) x, y = self.sample_xy(features, sample)
category_to_cids: Dict[Hashable, List[str]] = {} category_to_cids: Dict[Hashable, List[str]] = {}
for (cid, cdict) in features.constraints.items(): for (cid, cfeatures) in features.constraints.items():
if "Category" not in cdict or cdict["Category"] is None: if cfeatures.category is None:
continue continue
category = cdict["Category"] category = cfeatures.category
if category not in category_to_cids: if category not in category_to_cids:
category_to_cids[category] = [] category_to_cids[category] = []
category_to_cids[category] += [cid] category_to_cids[category] += [cid]
@ -181,15 +179,15 @@ class StaticLazyConstraintsComponent(Component):
x: Dict = {} x: Dict = {}
y: Dict = {} y: Dict = {}
for (cid, cfeatures) in features.constraints.items(): for (cid, cfeatures) in features.constraints.items():
if not cfeatures["Lazy"]: if not cfeatures.lazy:
continue continue
category = cfeatures["Category"] category = cfeatures.category
if category is None: if category is None:
continue continue
if category not in x: if category not in x:
x[category] = [] x[category] = []
y[category] = [] y[category] = []
x[category] += [cfeatures["User features"]] x[category] += [cfeatures.user_features]
if "LazyStatic: Enforced" in sample: if "LazyStatic: Enforced" in sample:
if cid in sample["LazyStatic: Enforced"]: if cid in sample["LazyStatic: Enforced"]:
y[category] += [[False, True]] y[category] += [[False, True]]

@ -78,6 +78,7 @@ class ObjectiveValueComponent(Component):
sample: TrainingSample, sample: TrainingSample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]: ) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
assert features.instance is not None assert features.instance is not None
assert features.instance.user_features is not None
x: Dict[Hashable, List[List[float]]] = {} x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {} y: Dict[Hashable, List[List[float]]] = {}
f = list(features.instance.user_features) f = list(features.instance.user_features)

@ -2,9 +2,9 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import numbers
import collections import collections
from typing import TYPE_CHECKING, Dict, Hashable import numbers
from typing import TYPE_CHECKING, Dict
from miplearn.types import ( from miplearn.types import (
Features, Features,
@ -87,17 +87,15 @@ class FeaturesExtractor:
f"Constraint features must be a list of floats. " f"Constraint features must be a list of floats. "
f"Found {type(user_features[0]).__name__} instead for cid={cid}." f"Found {type(user_features[0]).__name__} instead for cid={cid}."
) )
constraints[cid] = { constraints[cid] = ConstraintFeatures(
"RHS": self.solver.get_constraint_rhs(cid), rhs=self.solver.get_constraint_rhs(cid),
"LHS": self.solver.get_constraint_lhs(cid), lhs=self.solver.get_constraint_lhs(cid),
"Sense": self.solver.get_constraint_sense(cid), sense=self.solver.get_constraint_sense(cid),
"Category": category, category=category,
"User features": user_features, user_features=user_features,
} )
if has_static_lazy: if has_static_lazy:
constraints[cid]["Lazy"] = instance.is_constraint_lazy(cid) constraints[cid].lazy = instance.is_constraint_lazy(cid)
else:
constraints[cid]["Lazy"] = False
return constraints return constraints
@staticmethod @staticmethod
@ -118,7 +116,7 @@ class FeaturesExtractor:
) )
lazy_count = 0 lazy_count = 0
for (cid, cdict) in features.constraints.items(): for (cid, cdict) in features.constraints.items():
if cdict["Lazy"]: if cdict.lazy:
lazy_count += 1 lazy_count += 1
return InstanceFeatures( return InstanceFeatures(
user_features=user_features, user_features=user_features,

@ -81,7 +81,7 @@ LearningSolveStats = TypedDict(
@dataclass @dataclass
class InstanceFeatures: class InstanceFeatures:
user_features: List[float] user_features: Optional[List[float]] = None
lazy_constraint_count: int = 0 lazy_constraint_count: int = 0
@ -91,18 +91,14 @@ class VariableFeatures:
user_features: Optional[List[float]] = None user_features: Optional[List[float]] = None
ConstraintFeatures = TypedDict( @dataclass
"ConstraintFeatures", class ConstraintFeatures:
{ rhs: Optional[float] = None
"RHS": float, lhs: Optional[Dict[str, float]] = None
"LHS": Dict[str, float], sense: Optional[str] = None
"Sense": str, category: Optional[Hashable] = None
"Category": Optional[Hashable], user_features: Optional[List[float]] = None
"User features": Optional[List[float]], lazy: bool = False
"Lazy": bool,
},
total=False,
)
@dataclass @dataclass

@ -17,6 +17,7 @@ from miplearn.types import (
Features, Features,
LearningSolveStats, LearningSolveStats,
InstanceFeatures, InstanceFeatures,
ConstraintFeatures,
) )
@ -35,31 +36,31 @@ def features() -> Features:
lazy_constraint_count=4, lazy_constraint_count=4,
), ),
constraints={ constraints={
"c1": { "c1": ConstraintFeatures(
"Category": "type-a", category="type-a",
"User features": [1.0, 1.0], user_features=[1.0, 1.0],
"Lazy": True, lazy=True,
}, ),
"c2": { "c2": ConstraintFeatures(
"Category": "type-a", category="type-a",
"User features": [1.0, 2.0], user_features=[1.0, 2.0],
"Lazy": True, lazy=True,
}, ),
"c3": { "c3": ConstraintFeatures(
"Category": "type-a", category="type-a",
"User features": [1.0, 3.0], user_features=[1.0, 3.0],
"Lazy": True, lazy=True,
}, ),
"c4": { "c4": ConstraintFeatures(
"Category": "type-b", category="type-b",
"User features": [1.0, 4.0, 0.0], user_features=[1.0, 4.0, 0.0],
"Lazy": True, lazy=True,
}, ),
"c5": { "c5": ConstraintFeatures(
"Category": "type-b", category="type-b",
"User features": [1.0, 5.0, 0.0], user_features=[1.0, 5.0, 0.0],
"Lazy": False, lazy=False,
}, ),
}, },
) )

@ -4,7 +4,7 @@
from miplearn import GurobiSolver from miplearn import GurobiSolver
from miplearn.features import FeaturesExtractor from miplearn.features import FeaturesExtractor
from miplearn.types import VariableFeatures, InstanceFeatures from miplearn.types import VariableFeatures, InstanceFeatures, ConstraintFeatures
from tests.fixtures.knapsack import get_knapsack_instance from tests.fixtures.knapsack import get_knapsack_instance
@ -36,19 +36,19 @@ def test_knapsack() -> None:
} }
} }
assert instance.features.constraints == { assert instance.features.constraints == {
"eq_capacity": { "eq_capacity": ConstraintFeatures(
"LHS": { lhs={
"x[0]": 23.0, "x[0]": 23.0,
"x[1]": 26.0, "x[1]": 26.0,
"x[2]": 20.0, "x[2]": 20.0,
"x[3]": 18.0, "x[3]": 18.0,
}, },
"Sense": "<", sense="<",
"RHS": 67.0, rhs=67.0,
"Lazy": False, lazy=False,
"Category": "eq_capacity", category="eq_capacity",
"User features": [0.0], user_features=[0.0],
} )
} }
assert instance.features.instance == InstanceFeatures( assert instance.features.instance == InstanceFeatures(
user_features=[67.0, 21.75], user_features=[67.0, 21.75],

Loading…
Cancel
Save