mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Implement ConstraintFeatures.__getitem__
This commit is contained in:
@@ -107,6 +107,31 @@ class ConstraintFeatures:
|
|||||||
_clip(features)
|
_clip(features)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
def __getitem__(self, selected: Tuple[bool, ...]) -> "ConstraintFeatures":
|
||||||
|
return ConstraintFeatures(
|
||||||
|
basis_status=self._filter(self.basis_status, selected),
|
||||||
|
categories=self._filter(self.categories, selected),
|
||||||
|
dual_values=self._filter(self.dual_values, selected),
|
||||||
|
names=self._filter(self.names, selected),
|
||||||
|
lazy=self._filter(self.lazy, selected),
|
||||||
|
lhs=self._filter(self.lhs, selected),
|
||||||
|
rhs=self._filter(self.rhs, selected),
|
||||||
|
sa_rhs_down=self._filter(self.sa_rhs_down, selected),
|
||||||
|
sa_rhs_up=self._filter(self.sa_rhs_up, selected),
|
||||||
|
senses=self._filter(self.senses, selected),
|
||||||
|
slacks=self._filter(self.slacks, selected),
|
||||||
|
user_features=self._filter(self.user_features, selected),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _filter(
|
||||||
|
self,
|
||||||
|
obj: Optional[Tuple],
|
||||||
|
selected: Tuple[bool, ...],
|
||||||
|
) -> Optional[Tuple]:
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
return tuple(obj[i] for (i, selected_i) in enumerate(selected) if selected_i)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Constraint:
|
class Constraint:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from miplearn.features import (
|
|||||||
InstanceFeatures,
|
InstanceFeatures,
|
||||||
Constraint,
|
Constraint,
|
||||||
VariableFeatures,
|
VariableFeatures,
|
||||||
|
ConstraintFeatures,
|
||||||
)
|
)
|
||||||
from miplearn.solvers.gurobi import GurobiSolver
|
from miplearn.solvers.gurobi import GurobiSolver
|
||||||
from miplearn.solvers.tests import (
|
from miplearn.solvers.tests import (
|
||||||
@@ -89,3 +90,43 @@ def test_knapsack() -> None:
|
|||||||
lazy_constraint_count=0,
|
lazy_constraint_count=0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_constraint_getindex() -> None:
|
||||||
|
cf = ConstraintFeatures(
|
||||||
|
names=("c1", "c2", "c3"),
|
||||||
|
rhs=(1.0, 2.0, 3.0),
|
||||||
|
senses=("=", "<", ">"),
|
||||||
|
lhs=(
|
||||||
|
(
|
||||||
|
("x1", 1.0),
|
||||||
|
("x2", 1.0),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
("x2", 2.0),
|
||||||
|
("x3", 2.0),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
("x3", 3.0),
|
||||||
|
("x4", 3.0),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert_equals(
|
||||||
|
cf[True, False, True],
|
||||||
|
ConstraintFeatures(
|
||||||
|
names=("c1", "c3"),
|
||||||
|
rhs=(1.0, 3.0),
|
||||||
|
senses=("=", ">"),
|
||||||
|
lhs=(
|
||||||
|
(
|
||||||
|
("x1", 1.0),
|
||||||
|
("x2", 1.0),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
("x3", 3.0),
|
||||||
|
("x4", 3.0),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user