mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Add with_lhs argument
This commit is contained in:
@@ -166,8 +166,10 @@ class FeaturesExtractor:
|
||||
def __init__(
|
||||
self,
|
||||
with_sa: bool = True,
|
||||
with_lhs: bool = True,
|
||||
) -> None:
|
||||
self.with_sa = with_sa
|
||||
self.with_lhs = with_lhs
|
||||
|
||||
def extract(
|
||||
self,
|
||||
@@ -183,6 +185,7 @@ class FeaturesExtractor:
|
||||
features.constraints = solver.get_constraints(
|
||||
with_static=with_static,
|
||||
with_sa=self.with_sa,
|
||||
with_lhs=self.with_lhs,
|
||||
)
|
||||
features.constraints_old = solver.get_constraints_old(
|
||||
with_static=with_static,
|
||||
|
||||
@@ -167,6 +167,7 @@ class GurobiSolver(InternalSolver):
|
||||
self,
|
||||
with_static: bool = True,
|
||||
with_sa: bool = True,
|
||||
with_lhs: bool = True,
|
||||
) -> ConstraintFeatures:
|
||||
model = self.model
|
||||
assert model is not None
|
||||
@@ -187,14 +188,15 @@ class GurobiSolver(InternalSolver):
|
||||
if with_static:
|
||||
rhs = tuple(model.getAttr("rhs", gp_constrs))
|
||||
senses = tuple(model.getAttr("sense", gp_constrs))
|
||||
lhs_l: List = [None for _ in gp_constrs]
|
||||
for (i, gp_constr) in enumerate(gp_constrs):
|
||||
expr = model.getRow(gp_constr)
|
||||
lhs_l[i] = tuple(
|
||||
(self._var_names[expr.getVar(j).index], expr.getCoeff(j))
|
||||
for j in range(expr.size())
|
||||
)
|
||||
lhs = tuple(lhs_l)
|
||||
if with_lhs:
|
||||
lhs_l: List = [None for _ in gp_constrs]
|
||||
for (i, gp_constr) in enumerate(gp_constrs):
|
||||
expr = model.getRow(gp_constr)
|
||||
lhs_l[i] = tuple(
|
||||
(self._var_names[expr.getVar(j).index], expr.getCoeff(j))
|
||||
for j in range(expr.size())
|
||||
)
|
||||
lhs = tuple(lhs_l)
|
||||
|
||||
if self._has_lp_solution:
|
||||
dual_value = tuple(model.getAttr("pi", gp_constrs))
|
||||
|
||||
@@ -173,6 +173,7 @@ class InternalSolver(ABC, EnforceOverrides):
|
||||
self,
|
||||
with_static: bool = True,
|
||||
with_sa: bool = True,
|
||||
with_lhs: bool = True,
|
||||
) -> ConstraintFeatures:
|
||||
pass
|
||||
|
||||
|
||||
@@ -132,6 +132,7 @@ class BasePyomoSolver(InternalSolver):
|
||||
self,
|
||||
with_static: bool = True,
|
||||
with_sa: bool = True,
|
||||
with_lhs: bool = True,
|
||||
) -> ConstraintFeatures:
|
||||
model = self.model
|
||||
assert model is not None
|
||||
@@ -162,26 +163,32 @@ class BasePyomoSolver(InternalSolver):
|
||||
senses.append("=")
|
||||
rhs.append(float(c.upper()))
|
||||
|
||||
# Extract LHS
|
||||
lhsc = []
|
||||
expr = c.body
|
||||
if isinstance(expr, SumExpression):
|
||||
for term in expr._args_:
|
||||
if isinstance(term, MonomialTermExpression):
|
||||
lhsc.append((term._args_[1].name, float(term._args_[0])))
|
||||
elif isinstance(term, _GeneralVarData):
|
||||
lhsc.append((term.name, 1.0))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unknown term type: {term.__class__.__name__}"
|
||||
)
|
||||
elif isinstance(expr, _GeneralVarData):
|
||||
lhsc.append((expr.name, 1.0))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unknown expression type: {expr.__class__.__name__}"
|
||||
)
|
||||
lhs.append(tuple(lhsc))
|
||||
if with_lhs:
|
||||
# Extract LHS
|
||||
lhsc = []
|
||||
expr = c.body
|
||||
if isinstance(expr, SumExpression):
|
||||
for term in expr._args_:
|
||||
if isinstance(term, MonomialTermExpression):
|
||||
lhsc.append(
|
||||
(
|
||||
term._args_[1].name,
|
||||
float(term._args_[0]),
|
||||
)
|
||||
)
|
||||
elif isinstance(term, _GeneralVarData):
|
||||
lhsc.append((term.name, 1.0))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unknown term type: {term.__class__.__name__}"
|
||||
)
|
||||
elif isinstance(expr, _GeneralVarData):
|
||||
lhsc.append((expr.name, 1.0))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Unknown expression type: {expr.__class__.__name__}"
|
||||
)
|
||||
lhs.append(tuple(lhsc))
|
||||
|
||||
# Extract dual values
|
||||
if self._has_lp_solution:
|
||||
|
||||
Reference in New Issue
Block a user