diff --git a/miplearn/features.py b/miplearn/features.py index 88cead9..dcfd4f6 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -107,6 +107,31 @@ class ConstraintFeatures: _clip(features) return features + def __getitem__(self, selected: Tuple[bool, ...]) -> "ConstraintFeatures": + return ConstraintFeatures( + basis_status=self._filter(self.basis_status, selected), + categories=self._filter(self.categories, selected), + dual_values=self._filter(self.dual_values, selected), + names=self._filter(self.names, selected), + lazy=self._filter(self.lazy, selected), + lhs=self._filter(self.lhs, selected), + rhs=self._filter(self.rhs, selected), + sa_rhs_down=self._filter(self.sa_rhs_down, selected), + sa_rhs_up=self._filter(self.sa_rhs_up, selected), + senses=self._filter(self.senses, selected), + slacks=self._filter(self.slacks, selected), + user_features=self._filter(self.user_features, selected), + ) + + def _filter( + self, + obj: Optional[Tuple], + selected: Tuple[bool, ...], + ) -> Optional[Tuple]: + if obj is None: + return None + return tuple(obj[i] for (i, selected_i) in enumerate(selected) if selected_i) + @dataclass class Constraint: diff --git a/tests/test_features.py b/tests/test_features.py index 1b2b938..41a3da5 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -7,6 +7,7 @@ from miplearn.features import ( InstanceFeatures, Constraint, VariableFeatures, + ConstraintFeatures, ) from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.tests import ( @@ -89,3 +90,43 @@ def test_knapsack() -> None: lazy_constraint_count=0, ), ) + + +def test_constraint_getindex() -> None: + cf = ConstraintFeatures( + names=("c1", "c2", "c3"), + rhs=(1.0, 2.0, 3.0), + senses=("=", "<", ">"), + lhs=( + ( + ("x1", 1.0), + ("x2", 1.0), + ), + ( + ("x2", 2.0), + ("x3", 2.0), + ), + ( + ("x3", 3.0), + ("x4", 3.0), + ), + ), + ) + assert_equals( + cf[True, False, True], + ConstraintFeatures( + names=("c1", "c3"), + rhs=(1.0, 3.0), + senses=("=", ">"), + lhs=( + ( + ("x1", 1.0), + ("x2", 1.0), + ), + ( + ("x3", 3.0), + ("x4", 3.0), + ), + ), + ), + )