SplitFreeVars: Preserve var order

This commit is contained in:
2025-10-06 12:51:27 -05:00
parent 295e29c351
commit a4ff65275e

View File

@@ -179,71 +179,55 @@ end
# -----------------------------------------------------------------------------
Base.@kwdef mutable struct SplitFreeVars <: Transform
F::Int = 0
B::Int = 0
free::Vector{Int} = []
others::Vector{Int} = []
ncols::Int = 0
is_var_free::Vector{Bool} = []
end
function forward!(t::SplitFreeVars, data::ProblemData)
lhs = data.constr_lhs
_, ncols = size(lhs)
free = [i for i = 1:ncols if !isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i])]
others = [i for i = 1:ncols if isfinite(data.var_lb[i]) || isfinite(data.var_ub[i])]
t.F = length(free)
t.B = length(others)
t.free, t.others = free, others
is_var_free = [!isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i]) for i = 1:ncols]
free_idx = findall(is_var_free)
data.obj = [
data.obj[others]
data.obj[free]
-data.obj[free]
data.obj
[-data.obj[i] for i in free_idx]
]
data.constr_lhs = [lhs[:, others] lhs[:, free] -lhs[:, free]]
data.var_lb = [
data.var_lb[others]
[0.0 for _ in free]
[0.0 for _ in free]
[is_var_free[i] ? 0.0 : data.var_lb[i] for i in 1:ncols]
[0 for _ in free_idx]
]
data.var_ub = [
data.var_ub[others]
[Inf for _ in free]
[Inf for _ in free]
[is_var_free[i] ? Inf : data.var_ub[i] for i in 1:ncols]
[Inf for _ in free_idx]
]
data.var_types = [
data.var_types[others]
data.var_types[free]
data.var_types[free]
data.var_types
[data.var_types[i] for i in free_idx]
]
data.var_names = [
data.var_names[others]
["$(v)_p" for v in data.var_names[free]]
["$(v)_m" for v in data.var_names[free]]
data.var_names
["$(data.var_names[i])_neg" for i in free_idx]
]
data.constr_lhs = [lhs -lhs[:, free_idx]]
t.is_var_free, t.ncols = is_var_free, ncols
end
function backwards!(t::SplitFreeVars, c::ConstraintSet)
# Convert GE constraints into LE
nrows, _ = size(c.lhs)
ge = [i for i = 1:nrows if isfinite(c.lb[i])]
c.ub[ge], c.lb[ge] = -c.lb[ge], -c.ub[ge]
c.lhs[ge, :] *= -1
ncols, is_var_free = t.ncols, t.is_var_free
free_idx = findall(is_var_free)
# Assert only LE constraints are left (EQ constraints are not supported)
@assert all(c.lb .== -Inf)
# Combine split free variable coefficients: x = x_p - x_m
B, F = t.B, t.F
for i = 1:F
c.lhs[:, B+i] = c.lhs[:, B+i] - c.lhs[:, B+F+i]
for (offset, var_idx) in enumerate(free_idx)
@assert c.lhs[:, var_idx] == -c.lhs[:, ncols+offset]
end
c.lhs = c.lhs[:, 1:(B+F)]
c.lhs = c.lhs[:, 1:ncols]
end
function forward(t::SplitFreeVars, p::Vector{Float64})::Vector{Float64}
ncols, is_var_free = t.ncols, t.is_var_free
free_idx = findall(is_var_free)
return [
p[t.others]
max.(p[t.free], 0)
max.(-p[t.free], 0)
[is_var_free[i] ? max(0, p[i]) : p[i] for i in 1:ncols]
[max(0, -p[i]) for i in free_idx]
]
end