mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Convert ConstraintFeatures to dataclass
This commit is contained in:
@@ -3,11 +3,9 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, Tuple, Optional, List, Hashable, Any, TYPE_CHECKING, Set
|
||||
from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from miplearn import Classifier
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
@@ -71,7 +69,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
logger.info("Moving lazy constraints to the pool...")
|
||||
self.pool = {}
|
||||
for (cid, cdict) in features.constraints.items():
|
||||
if cdict["Lazy"] and cid not in self.enforced_cids:
|
||||
if cdict.lazy and cid not in self.enforced_cids:
|
||||
self.pool[cid] = LazyConstraint(
|
||||
cid=cid,
|
||||
obj=solver.internal_solver.extract_constraint(cid),
|
||||
@@ -152,10 +150,10 @@ class StaticLazyConstraintsComponent(Component):
|
||||
|
||||
x, y = self.sample_xy(features, sample)
|
||||
category_to_cids: Dict[Hashable, List[str]] = {}
|
||||
for (cid, cdict) in features.constraints.items():
|
||||
if "Category" not in cdict or cdict["Category"] is None:
|
||||
for (cid, cfeatures) in features.constraints.items():
|
||||
if cfeatures.category is None:
|
||||
continue
|
||||
category = cdict["Category"]
|
||||
category = cfeatures.category
|
||||
if category not in category_to_cids:
|
||||
category_to_cids[category] = []
|
||||
category_to_cids[category] += [cid]
|
||||
@@ -181,15 +179,15 @@ class StaticLazyConstraintsComponent(Component):
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
for (cid, cfeatures) in features.constraints.items():
|
||||
if not cfeatures["Lazy"]:
|
||||
if not cfeatures.lazy:
|
||||
continue
|
||||
category = cfeatures["Category"]
|
||||
category = cfeatures.category
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
x[category] += [cfeatures["User features"]]
|
||||
x[category] += [cfeatures.user_features]
|
||||
if "LazyStatic: Enforced" in sample:
|
||||
if cid in sample["LazyStatic: Enforced"]:
|
||||
y[category] += [[False, True]]
|
||||
|
||||
@@ -78,6 +78,7 @@ class ObjectiveValueComponent(Component):
|
||||
sample: TrainingSample,
|
||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||
assert features.instance is not None
|
||||
assert features.instance.user_features is not None
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
f = list(features.instance.user_features)
|
||||
|
||||
Reference in New Issue
Block a user