mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Convert ConstraintFeatures to dataclass
This commit is contained in:
@@ -3,11 +3,9 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, Tuple, Optional, List, Hashable, Any, TYPE_CHECKING, Set
|
||||
from typing import Dict, Tuple, List, Hashable, Any, TYPE_CHECKING, Set
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from miplearn import Classifier
|
||||
from miplearn.classifiers.counting import CountingClassifier
|
||||
@@ -71,7 +69,7 @@ class StaticLazyConstraintsComponent(Component):
|
||||
logger.info("Moving lazy constraints to the pool...")
|
||||
self.pool = {}
|
||||
for (cid, cdict) in features.constraints.items():
|
||||
if cdict["Lazy"] and cid not in self.enforced_cids:
|
||||
if cdict.lazy and cid not in self.enforced_cids:
|
||||
self.pool[cid] = LazyConstraint(
|
||||
cid=cid,
|
||||
obj=solver.internal_solver.extract_constraint(cid),
|
||||
@@ -152,10 +150,10 @@ class StaticLazyConstraintsComponent(Component):
|
||||
|
||||
x, y = self.sample_xy(features, sample)
|
||||
category_to_cids: Dict[Hashable, List[str]] = {}
|
||||
for (cid, cdict) in features.constraints.items():
|
||||
if "Category" not in cdict or cdict["Category"] is None:
|
||||
for (cid, cfeatures) in features.constraints.items():
|
||||
if cfeatures.category is None:
|
||||
continue
|
||||
category = cdict["Category"]
|
||||
category = cfeatures.category
|
||||
if category not in category_to_cids:
|
||||
category_to_cids[category] = []
|
||||
category_to_cids[category] += [cid]
|
||||
@@ -181,15 +179,15 @@ class StaticLazyConstraintsComponent(Component):
|
||||
x: Dict = {}
|
||||
y: Dict = {}
|
||||
for (cid, cfeatures) in features.constraints.items():
|
||||
if not cfeatures["Lazy"]:
|
||||
if not cfeatures.lazy:
|
||||
continue
|
||||
category = cfeatures["Category"]
|
||||
category = cfeatures.category
|
||||
if category is None:
|
||||
continue
|
||||
if category not in x:
|
||||
x[category] = []
|
||||
y[category] = []
|
||||
x[category] += [cfeatures["User features"]]
|
||||
x[category] += [cfeatures.user_features]
|
||||
if "LazyStatic: Enforced" in sample:
|
||||
if cid in sample["LazyStatic: Enforced"]:
|
||||
y[category] += [[False, True]]
|
||||
|
||||
@@ -78,6 +78,7 @@ class ObjectiveValueComponent(Component):
|
||||
sample: TrainingSample,
|
||||
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
|
||||
assert features.instance is not None
|
||||
assert features.instance.user_features is not None
|
||||
x: Dict[Hashable, List[List[float]]] = {}
|
||||
y: Dict[Hashable, List[List[float]]] = {}
|
||||
f = list(features.instance.user_features)
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
import numbers
|
||||
import collections
|
||||
from typing import TYPE_CHECKING, Dict, Hashable
|
||||
import numbers
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from miplearn.types import (
|
||||
Features,
|
||||
@@ -87,17 +87,15 @@ class FeaturesExtractor:
|
||||
f"Constraint features must be a list of floats. "
|
||||
f"Found {type(user_features[0]).__name__} instead for cid={cid}."
|
||||
)
|
||||
constraints[cid] = {
|
||||
"RHS": self.solver.get_constraint_rhs(cid),
|
||||
"LHS": self.solver.get_constraint_lhs(cid),
|
||||
"Sense": self.solver.get_constraint_sense(cid),
|
||||
"Category": category,
|
||||
"User features": user_features,
|
||||
}
|
||||
constraints[cid] = ConstraintFeatures(
|
||||
rhs=self.solver.get_constraint_rhs(cid),
|
||||
lhs=self.solver.get_constraint_lhs(cid),
|
||||
sense=self.solver.get_constraint_sense(cid),
|
||||
category=category,
|
||||
user_features=user_features,
|
||||
)
|
||||
if has_static_lazy:
|
||||
constraints[cid]["Lazy"] = instance.is_constraint_lazy(cid)
|
||||
else:
|
||||
constraints[cid]["Lazy"] = False
|
||||
constraints[cid].lazy = instance.is_constraint_lazy(cid)
|
||||
return constraints
|
||||
|
||||
@staticmethod
|
||||
@@ -118,7 +116,7 @@ class FeaturesExtractor:
|
||||
)
|
||||
lazy_count = 0
|
||||
for (cid, cdict) in features.constraints.items():
|
||||
if cdict["Lazy"]:
|
||||
if cdict.lazy:
|
||||
lazy_count += 1
|
||||
return InstanceFeatures(
|
||||
user_features=user_features,
|
||||
|
||||
@@ -81,7 +81,7 @@ LearningSolveStats = TypedDict(
|
||||
|
||||
@dataclass
|
||||
class InstanceFeatures:
|
||||
user_features: List[float]
|
||||
user_features: Optional[List[float]] = None
|
||||
lazy_constraint_count: int = 0
|
||||
|
||||
|
||||
@@ -91,18 +91,14 @@ class VariableFeatures:
|
||||
user_features: Optional[List[float]] = None
|
||||
|
||||
|
||||
ConstraintFeatures = TypedDict(
|
||||
"ConstraintFeatures",
|
||||
{
|
||||
"RHS": float,
|
||||
"LHS": Dict[str, float],
|
||||
"Sense": str,
|
||||
"Category": Optional[Hashable],
|
||||
"User features": Optional[List[float]],
|
||||
"Lazy": bool,
|
||||
},
|
||||
total=False,
|
||||
)
|
||||
@dataclass
|
||||
class ConstraintFeatures:
|
||||
rhs: Optional[float] = None
|
||||
lhs: Optional[Dict[str, float]] = None
|
||||
sense: Optional[str] = None
|
||||
category: Optional[Hashable] = None
|
||||
user_features: Optional[List[float]] = None
|
||||
lazy: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user