Convert ConstraintFeatures to dataclass

This commit is contained in:
2021-04-05 20:12:07 -05:00
parent 94084e0669
commit aeed338837
6 changed files with 64 additions and 70 deletions

View File

@@ -3,11 +3,9 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
import sys
from typing import Dict, Tuple, Optional, List, Hashable, Any, TYPE_CHECKING, Set
from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set
import numpy as np
from tqdm.auto import tqdm
from miplearn import Classifier
from miplearn.classifiers.counting import CountingClassifier
@@ -71,7 +69,7 @@ class StaticLazyConstraintsComponent(Component):
logger.info("Moving lazy constraints to the pool...")
self.pool = {}
for (cid, cdict) in features.constraints.items():
if cdict["Lazy"] and cid not in self.enforced_cids:
if cdict.lazy and cid not in self.enforced_cids:
self.pool[cid] = LazyConstraint(
cid=cid,
obj=solver.internal_solver.extract_constraint(cid),
@@ -152,10 +150,10 @@ class StaticLazyConstraintsComponent(Component):
x, y = self.sample_xy(features, sample)
category_to_cids: Dict[Hashable, List[str]] = {}
for (cid, cdict) in features.constraints.items():
if "Category" not in cdict or cdict["Category"] is None:
for (cid, cfeatures) in features.constraints.items():
if cfeatures.category is None:
continue
category = cdict["Category"]
category = cfeatures.category
if category not in category_to_cids:
category_to_cids[category] = []
category_to_cids[category] += [cid]
@@ -181,15 +179,15 @@ class StaticLazyConstraintsComponent(Component):
x: Dict = {}
y: Dict = {}
for (cid, cfeatures) in features.constraints.items():
if not cfeatures["Lazy"]:
if not cfeatures.lazy:
continue
category = cfeatures["Category"]
category = cfeatures.category
if category is None:
continue
if category not in x:
x[category] = []
y[category] = []
x[category] += [cfeatures["User features"]]
x[category] += [cfeatures.user_features]
if "LazyStatic: Enforced" in sample:
if cid in sample["LazyStatic: Enforced"]:
y[category] += [[False, True]]

View File

@@ -78,6 +78,7 @@ class ObjectiveValueComponent(Component):
sample: TrainingSample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
assert features.instance is not None
assert features.instance.user_features is not None
x: Dict[Hashable, List[List[float]]] = {}
y: Dict[Hashable, List[List[float]]] = {}
f = list(features.instance.user_features)