SplitFreeVars: Preserve var order

This commit is contained in:
2025-10-06 12:51:27 -05:00
parent 295e29c351
commit a4ff65275e

View File

@@ -179,71 +179,55 @@ end
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
Base.@kwdef mutable struct SplitFreeVars <: Transform Base.@kwdef mutable struct SplitFreeVars <: Transform
F::Int = 0 ncols::Int = 0
B::Int = 0 is_var_free::Vector{Bool} = []
free::Vector{Int} = []
others::Vector{Int} = []
end end
function forward!(t::SplitFreeVars, data::ProblemData) function forward!(t::SplitFreeVars, data::ProblemData)
lhs = data.constr_lhs lhs = data.constr_lhs
_, ncols = size(lhs) _, ncols = size(lhs)
free = [i for i = 1:ncols if !isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i])] is_var_free = [!isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i]) for i = 1:ncols]
others = [i for i = 1:ncols if isfinite(data.var_lb[i]) || isfinite(data.var_ub[i])] free_idx = findall(is_var_free)
t.F = length(free)
t.B = length(others)
t.free, t.others = free, others
data.obj = [ data.obj = [
data.obj[others] data.obj
data.obj[free] [-data.obj[i] for i in free_idx]
-data.obj[free]
] ]
data.constr_lhs = [lhs[:, others] lhs[:, free] -lhs[:, free]]
data.var_lb = [ data.var_lb = [
data.var_lb[others] [is_var_free[i] ? 0.0 : data.var_lb[i] for i in 1:ncols]
[0.0 for _ in free] [0 for _ in free_idx]
[0.0 for _ in free]
] ]
data.var_ub = [ data.var_ub = [
data.var_ub[others] [is_var_free[i] ? Inf : data.var_ub[i] for i in 1:ncols]
[Inf for _ in free] [Inf for _ in free_idx]
[Inf for _ in free]
] ]
data.var_types = [ data.var_types = [
data.var_types[others] data.var_types
data.var_types[free] [data.var_types[i] for i in free_idx]
data.var_types[free]
] ]
data.var_names = [ data.var_names = [
data.var_names[others] data.var_names
["$(v)_p" for v in data.var_names[free]] ["$(data.var_names[i])_neg" for i in free_idx]
["$(v)_m" for v in data.var_names[free]]
] ]
data.constr_lhs = [lhs -lhs[:, free_idx]]
t.is_var_free, t.ncols = is_var_free, ncols
end end
function backwards!(t::SplitFreeVars, c::ConstraintSet) function backwards!(t::SplitFreeVars, c::ConstraintSet)
# Convert GE constraints into LE ncols, is_var_free = t.ncols, t.is_var_free
nrows, _ = size(c.lhs) free_idx = findall(is_var_free)
ge = [i for i = 1:nrows if isfinite(c.lb[i])]
c.ub[ge], c.lb[ge] = -c.lb[ge], -c.ub[ge]
c.lhs[ge, :] *= -1
# Assert only LE constraints are left (EQ constraints are not supported) for (offset, var_idx) in enumerate(free_idx)
@assert all(c.lb .== -Inf) @assert c.lhs[:, var_idx] == -c.lhs[:, ncols+offset]
# Combine split free variable coefficients: x = x_p - x_m
B, F = t.B, t.F
for i = 1:F
c.lhs[:, B+i] = c.lhs[:, B+i] - c.lhs[:, B+F+i]
end end
c.lhs = c.lhs[:, 1:(B+F)] c.lhs = c.lhs[:, 1:ncols]
end end
function forward(t::SplitFreeVars, p::Vector{Float64})::Vector{Float64} function forward(t::SplitFreeVars, p::Vector{Float64})::Vector{Float64}
ncols, is_var_free = t.ncols, t.is_var_free
free_idx = findall(is_var_free)
return [ return [
p[t.others] [is_var_free[i] ? max(0, p[i]) : p[i] for i in 1:ncols]
max.(p[t.free], 0) [max(0, -p[i]) for i in free_idx]
max.(-p[t.free], 0)
] ]
end end