mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 00:18:51 -06:00
SplitFreeVars: Preserve var order
This commit is contained in:
@@ -179,71 +179,55 @@ end
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
Base.@kwdef mutable struct SplitFreeVars <: Transform
|
||||
F::Int = 0
|
||||
B::Int = 0
|
||||
free::Vector{Int} = []
|
||||
others::Vector{Int} = []
|
||||
ncols::Int = 0
|
||||
is_var_free::Vector{Bool} = []
|
||||
end
|
||||
|
||||
function forward!(t::SplitFreeVars, data::ProblemData)
|
||||
lhs = data.constr_lhs
|
||||
_, ncols = size(lhs)
|
||||
free = [i for i = 1:ncols if !isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i])]
|
||||
others = [i for i = 1:ncols if isfinite(data.var_lb[i]) || isfinite(data.var_ub[i])]
|
||||
t.F = length(free)
|
||||
t.B = length(others)
|
||||
t.free, t.others = free, others
|
||||
is_var_free = [!isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i]) for i = 1:ncols]
|
||||
free_idx = findall(is_var_free)
|
||||
data.obj = [
|
||||
data.obj[others]
|
||||
data.obj[free]
|
||||
-data.obj[free]
|
||||
data.obj
|
||||
[-data.obj[i] for i in free_idx]
|
||||
]
|
||||
data.constr_lhs = [lhs[:, others] lhs[:, free] -lhs[:, free]]
|
||||
data.var_lb = [
|
||||
data.var_lb[others]
|
||||
[0.0 for _ in free]
|
||||
[0.0 for _ in free]
|
||||
[is_var_free[i] ? 0.0 : data.var_lb[i] for i in 1:ncols]
|
||||
[0 for _ in free_idx]
|
||||
]
|
||||
data.var_ub = [
|
||||
data.var_ub[others]
|
||||
[Inf for _ in free]
|
||||
[Inf for _ in free]
|
||||
[is_var_free[i] ? Inf : data.var_ub[i] for i in 1:ncols]
|
||||
[Inf for _ in free_idx]
|
||||
]
|
||||
data.var_types = [
|
||||
data.var_types[others]
|
||||
data.var_types[free]
|
||||
data.var_types[free]
|
||||
data.var_types
|
||||
[data.var_types[i] for i in free_idx]
|
||||
]
|
||||
data.var_names = [
|
||||
data.var_names[others]
|
||||
["$(v)_p" for v in data.var_names[free]]
|
||||
["$(v)_m" for v in data.var_names[free]]
|
||||
data.var_names
|
||||
["$(data.var_names[i])_neg" for i in free_idx]
|
||||
]
|
||||
data.constr_lhs = [lhs -lhs[:, free_idx]]
|
||||
t.is_var_free, t.ncols = is_var_free, ncols
|
||||
end
|
||||
|
||||
function backwards!(t::SplitFreeVars, c::ConstraintSet)
|
||||
# Convert GE constraints into LE
|
||||
nrows, _ = size(c.lhs)
|
||||
ge = [i for i = 1:nrows if isfinite(c.lb[i])]
|
||||
c.ub[ge], c.lb[ge] = -c.lb[ge], -c.ub[ge]
|
||||
c.lhs[ge, :] *= -1
|
||||
ncols, is_var_free = t.ncols, t.is_var_free
|
||||
free_idx = findall(is_var_free)
|
||||
|
||||
# Assert only LE constraints are left (EQ constraints are not supported)
|
||||
@assert all(c.lb .== -Inf)
|
||||
|
||||
# Combine split free variable coefficients: x = x_p - x_m
|
||||
B, F = t.B, t.F
|
||||
for i = 1:F
|
||||
c.lhs[:, B+i] = c.lhs[:, B+i] - c.lhs[:, B+F+i]
|
||||
for (offset, var_idx) in enumerate(free_idx)
|
||||
@assert c.lhs[:, var_idx] == -c.lhs[:, ncols+offset]
|
||||
end
|
||||
c.lhs = c.lhs[:, 1:(B+F)]
|
||||
c.lhs = c.lhs[:, 1:ncols]
|
||||
end
|
||||
|
||||
function forward(t::SplitFreeVars, p::Vector{Float64})::Vector{Float64}
|
||||
ncols, is_var_free = t.ncols, t.is_var_free
|
||||
free_idx = findall(is_var_free)
|
||||
return [
|
||||
p[t.others]
|
||||
max.(p[t.free], 0)
|
||||
max.(-p[t.free], 0)
|
||||
[is_var_free[i] ? max(0, p[i]) : p[i] for i in 1:ncols]
|
||||
[max(0, -p[i]) for i in free_idx]
|
||||
]
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user