diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index d8761b9..ba96463 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -217,7 +217,7 @@ class GurobiSolver(InternalSolver): dual_value, basis_status, sa_rhs_up, sa_rhs_down = None, None, None, None if with_static: - rhs = model.getAttr("rhs", gp_constrs) + rhs = np.array(model.getAttr("rhs", gp_constrs), dtype=float) senses = model.getAttr("sense", gp_constrs) if with_lhs: lhs = [None for _ in gp_constrs] @@ -229,7 +229,7 @@ class GurobiSolver(InternalSolver): ] if self._has_lp_solution: - dual_value = model.getAttr("pi", gp_constrs) + dual_value = np.array(model.getAttr("pi", gp_constrs), dtype=float) basis_status = list( map( _parse_gurobi_cbasis, @@ -237,11 +237,13 @@ class GurobiSolver(InternalSolver): ) ) if with_sa: - sa_rhs_up = model.getAttr("saRhsUp", gp_constrs) - sa_rhs_down = model.getAttr("saRhsLow", gp_constrs) + sa_rhs_up = np.array(model.getAttr("saRhsUp", gp_constrs), dtype=float) + sa_rhs_down = np.array( + model.getAttr("saRhsLow", gp_constrs), dtype=float + ) if self._has_lp_solution or self._has_mip_solution: - slacks = model.getAttr("slack", gp_constrs) + slacks = np.array(model.getAttr("slack", gp_constrs), dtype=float) return Constraints( basis_status=basis_status, diff --git a/miplearn/solvers/internal.py b/miplearn/solvers/internal.py index c8dbd84..5754aa9 100644 --- a/miplearn/solvers/internal.py +++ b/miplearn/solvers/internal.py @@ -69,15 +69,15 @@ class Variables: @dataclass class Constraints: basis_status: Optional[List[str]] = None - dual_values: Optional[List[float]] = None + dual_values: Optional[np.ndarray] = None lazy: Optional[List[bool]] = None lhs: Optional[List[List[Tuple[str, float]]]] = None names: Optional[List[str]] = None - rhs: Optional[List[float]] = None - sa_rhs_down: Optional[List[float]] = None - sa_rhs_up: Optional[List[float]] = None + rhs: Optional[np.ndarray] = None + sa_rhs_down: Optional[np.ndarray] = None + sa_rhs_up: Optional[np.ndarray] = None senses: Optional[List[str]] = None - slacks: Optional[List[float]] = None + slacks: Optional[np.ndarray] = None @staticmethod def from_sample(sample: "Sample") -> "Constraints": @@ -97,15 +97,19 @@ class Constraints: def __getitem__(self, selected: List[bool]) -> "Constraints": return Constraints( basis_status=self._filter(self.basis_status, selected), - dual_values=self._filter(self.dual_values, selected), + dual_values=( + None if self.dual_values is None else self.dual_values[selected] + ), names=self._filter(self.names, selected), lazy=self._filter(self.lazy, selected), lhs=self._filter(self.lhs, selected), - rhs=self._filter(self.rhs, selected), - sa_rhs_down=self._filter(self.sa_rhs_down, selected), - sa_rhs_up=self._filter(self.sa_rhs_up, selected), + rhs=(None if self.rhs is None else self.rhs[selected]), + sa_rhs_down=( + None if self.sa_rhs_down is None else self.sa_rhs_down[selected] + ), + sa_rhs_up=(None if self.sa_rhs_up is None else self.sa_rhs_up[selected]), senses=self._filter(self.senses, selected), - slacks=self._filter(self.slacks, selected), + slacks=(None if self.slacks is None else self.slacks[selected]), ) def _filter( diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index 698c5a4..b3640ec 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -236,11 +236,11 @@ class BasePyomoSolver(InternalSolver): return Constraints( names=_none_if_empty(names), - rhs=_none_if_empty(rhs), + rhs=_none_if_empty(np.array(rhs, dtype=float)), senses=_none_if_empty(senses), lhs=_none_if_empty(lhs), - slacks=_none_if_empty(slacks), - dual_values=_none_if_empty(dual_values), + slacks=_none_if_empty(np.array(slacks, dtype=float)), + dual_values=_none_if_empty(np.array(dual_values, dtype=float)), ) @overrides diff --git a/miplearn/solvers/tests/__init__.py b/miplearn/solvers/tests/__init__.py index b425b0d..c92b07b 100644 --- a/miplearn/solvers/tests/__init__.py +++ b/miplearn/solvers/tests/__init__.py @@ -53,7 +53,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: solver.get_constraints(), Constraints( names=["eq_capacity"], - rhs=[0.0], + rhs=np.array([0.0]), lhs=[ [ ("x[0]", 23.0), @@ -108,11 +108,11 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: solver.get_constraint_attrs(), Constraints( basis_status=["N"], - dual_values=[13.538462], + dual_values=np.array([13.538462]), names=["eq_capacity"], - sa_rhs_down=[-24.0], - sa_rhs_up=[2.0], - slacks=[0.0], + sa_rhs_down=np.array([-24.0]), + sa_rhs_up=np.array([2.0]), + slacks=np.array([0.0]), ), ), ) @@ -153,7 +153,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: solver.get_constraint_attrs(), Constraints( names=["eq_capacity"], - slacks=[0.0], + slacks=np.array([0.0]), ), ), ) @@ -162,7 +162,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: cf = Constraints( names=["cut"], lhs=[[("x[0]", 1.0)]], - rhs=[0.0], + rhs=np.array([0.0]), senses=["<"], ) assert_equals(solver.are_constraints_satisfied(cf), [False]) @@ -175,7 +175,7 @@ def run_basic_usage_tests(solver: InternalSolver) -> None: solver.get_constraint_attrs(), Constraints( names=["eq_capacity", "cut"], - rhs=[0.0, 0.0], + rhs=np.array([0.0, 0.0]), lhs=[ [ ("x[0]", 23.0), @@ -274,7 +274,7 @@ def _equals_preprocess(obj: Any) -> Any: return np.round(obj, decimals=6).tolist() else: return obj.tolist() - elif isinstance(obj, (int, str)): + elif isinstance(obj, (int, str, bool, np.bool_)): return obj elif isinstance(obj, float): return round(obj, 6) diff --git a/tests/features/test_extractor.py b/tests/features/test_extractor.py index d6ce933..2fc1633 100644 --- a/tests/features/test_extractor.py +++ b/tests/features/test_extractor.py @@ -122,7 +122,7 @@ def test_knapsack() -> None: def test_constraint_getindex() -> None: cf = Constraints( names=["c1", "c2", "c3"], - rhs=[1.0, 2.0, 3.0], + rhs=np.array([1.0, 2.0, 3.0]), senses=["=", "<", ">"], lhs=[ [ @@ -143,7 +143,7 @@ def test_constraint_getindex() -> None: cf[[True, False, True]], Constraints( names=["c1", "c3"], - rhs=[1.0, 3.0], + rhs=np.array([1.0, 3.0]), senses=["=", ">"], lhs=[ [