mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-09 19:08:51 -06:00
Create ConstraintFeatures
This commit is contained in:
@@ -76,6 +76,38 @@ class VariableFeatures:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConstraintFeatures:
|
||||
basis_status: Optional[Tuple[str, ...]] = None
|
||||
categories: Optional[Tuple[Optional[Hashable], ...]] = None
|
||||
dual_values: Optional[Tuple[float, ...]] = None
|
||||
names: Optional[Tuple[str, ...]] = None
|
||||
lazy: Optional[Tuple[bool, ...]] = None
|
||||
lhs: Optional[Tuple[Tuple[Tuple[str, float], ...], ...]] = None
|
||||
rhs: Optional[Tuple[float, ...]] = None
|
||||
sa_rhs_down: Optional[Tuple[float, ...]] = None
|
||||
sa_rhs_up: Optional[Tuple[float, ...]] = None
|
||||
senses: Optional[Tuple[str, ...]] = None
|
||||
slacks: Optional[Tuple[float, ...]] = None
|
||||
user_features: Optional[Tuple[Optional[Tuple[float, ...]], ...]] = None
|
||||
|
||||
def to_list(self, index: int) -> List[float]:
|
||||
features: List[float] = []
|
||||
for attr in [
|
||||
"dual_values",
|
||||
"rhs",
|
||||
"slacks",
|
||||
]:
|
||||
if getattr(self, attr) is not None:
|
||||
features.append(getattr(self, attr)[index])
|
||||
for attr in ["user_features"]:
|
||||
if getattr(self, attr) is not None:
|
||||
if getattr(self, attr)[index] is not None:
|
||||
features.extend(getattr(self, attr)[index])
|
||||
_clip(features)
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class Constraint:
|
||||
basis_status: Optional[str] = None
|
||||
@@ -147,7 +179,7 @@ class FeaturesExtractor:
|
||||
with_static=with_static,
|
||||
with_sa=self.with_sa,
|
||||
)
|
||||
features.constraints_old = solver.get_constraints(
|
||||
features.constraints_old = solver.get_constraints_old(
|
||||
with_static=with_static,
|
||||
)
|
||||
if with_static:
|
||||
|
||||
Reference in New Issue
Block a user