From b6ea0c5f1b1d865b08bd85e7bf29e51e32a8eca9 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 2 Mar 2021 18:14:07 -0600 Subject: [PATCH] ConstraintFeatures: Store lhs and sense --- miplearn/features.py | 18 +++++++++++------- miplearn/solvers/gurobi.py | 9 +++++++++ miplearn/solvers/internal.py | 13 ++++++++++++- miplearn/solvers/pyomo/base.py | 7 ++++++- miplearn/types.py | 14 ++++++++++++-- tests/test_features.py | 14 +++++++++++--- 6 files changed, 61 insertions(+), 14 deletions(-) diff --git a/miplearn/features.py b/miplearn/features.py index c18a2c2..bf35193 100644 --- a/miplearn/features.py +++ b/miplearn/features.py @@ -2,9 +2,9 @@ # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict -from miplearn.types import ModelFeatures +from miplearn.types import ModelFeatures, ConstraintFeatures if TYPE_CHECKING: from miplearn import InternalSolver @@ -15,12 +15,16 @@ class ModelFeaturesExtractor: self, internal_solver: "InternalSolver", ) -> None: - self.internal_solver = internal_solver + self.solver = internal_solver def extract(self) -> ModelFeatures: - rhs = {} - for cid in self.internal_solver.get_constraint_ids(): - rhs[cid] = self.internal_solver.get_constraint_rhs(cid) + constraints: Dict[str, ConstraintFeatures] = {} + for cid in self.solver.get_constraint_ids(): + constraints[cid] = { + "rhs": self.solver.get_constraint_rhs(cid), + "lhs": self.solver.get_constraint_lhs(cid), + "sense": self.solver.get_constraint_sense(cid), + } return { - "ConstraintRHS": rhs, + "constraints": constraints, } diff --git a/miplearn/solvers/gurobi.py b/miplearn/solvers/gurobi.py index 3a735d3..1aeff1a 100644 --- a/miplearn/solvers/gurobi.py +++ b/miplearn/solvers/gurobi.py @@ -339,6 +339,15 @@ class GurobiSolver(InternalSolver): assert self.model is not None return self.model.getConstrByName(cid).rhs + def get_constraint_lhs(self, cid: str) -> Dict[str, float]: + assert self.model is not None + constr = self.model.getConstrByName(cid) + expr = self.model.getRow(constr) + lhs: Dict[str, float] = {} + for i in range(expr.size()): + lhs[expr.getVar(i).varName] = expr.getCoeff(i) + return lhs + def extract_constraint(self, cid): self._raise_if_callback() constr = self.model.getConstrByName(cid) diff --git a/miplearn/solvers/internal.py b/miplearn/solvers/internal.py index 728f1be..48a352a 100644 --- a/miplearn/solvers/internal.py +++ b/miplearn/solvers/internal.py @@ -4,7 +4,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from miplearn.instance import Instance from miplearn.types import ( @@ -162,6 +162,17 @@ class InternalSolver(ABC): """ pass + @abstractmethod + def get_constraint_lhs(self, cid: str) -> Dict[str, float]: + """ + Returns a list of tuples encoding the left-hand side of the constraint. + + The first element of the tuple is the name of the variable and the second + element is the coefficient. For example, the left-hand side of "2 x1 + x2 <= 3" + is encoded as [{"x1": 2, "x2": 1}]. + """ + pass + @abstractmethod def add_constraint(self, cobj: Constraint) -> None: """ diff --git a/miplearn/solvers/pyomo/base.py b/miplearn/solvers/pyomo/base.py index de800be..e5d1b33 100644 --- a/miplearn/solvers/pyomo/base.py +++ b/miplearn/solvers/pyomo/base.py @@ -298,7 +298,9 @@ class BasePyomoSolver(InternalSolver): cobj = self._cname_to_constr[cid] has_ub = cobj.has_ub() has_lb = cobj.has_lb() - assert (not has_lb) or (not has_ub), "range constraints not supported" + assert ( + (not has_lb) or (not has_ub) or cobj.upper() == cobj.lower() + ), "range constraints not supported" if has_lb: return ">" elif has_ub: @@ -313,6 +315,9 @@ class BasePyomoSolver(InternalSolver): else: return cobj.lower() + def get_constraint_lhs(self, cid: str) -> Dict[str, float]: + return {} + def set_constraint_sense(self, cid: str, sense: str) -> None: raise Exception("Not implemented") diff --git a/miplearn/types.py b/miplearn/types.py index 8d58163..35ab0af 100644 --- a/miplearn/types.py +++ b/miplearn/types.py @@ -2,7 +2,7 @@ # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. -from typing import Optional, Dict, Callable, Any, Union, Tuple +from typing import Optional, Dict, Callable, Any, Union, Tuple, List from mypy_extensions import TypedDict @@ -71,10 +71,20 @@ LearningSolveStats = TypedDict( total=False, ) +ConstraintFeatures = TypedDict( + "ConstraintFeatures", + { + "rhs": float, + "lhs": Dict[str, float], + "sense": str, + }, + total=False, +) + ModelFeatures = TypedDict( "ModelFeatures", { - "ConstraintRHS": Dict[str, float], + "constraints": Dict[str, ConstraintFeatures], }, total=False, ) diff --git a/tests/test_features.py b/tests/test_features.py index 22c4b85..5f597e9 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,13 +1,14 @@ # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. + +from miplearn import GurobiSolver from miplearn.features import ModelFeaturesExtractor from tests.fixtures.knapsack import get_knapsack_instance -from tests.solvers import get_internal_solvers def test_knapsack() -> None: - for solver_factory in get_internal_solvers(): + for solver_factory in [GurobiSolver]: # Initialize model, instance and internal solver solver = solver_factory() instance = get_knapsack_instance(solver) @@ -20,4 +21,11 @@ def test_knapsack() -> None: # Test constraint features print(solver, features) - assert features["ConstraintRHS"]["eq_capacity"] == 67.0 + assert features["constraints"]["eq_capacity"]["lhs"] == { + "x[0]": 23.0, + "x[1]": 26.0, + "x[2]": 20.0, + "x[3]": 18.0, + } + assert features["constraints"]["eq_capacity"]["sense"] == "<" + assert features["constraints"]["eq_capacity"]["rhs"] == 67.0