FisSal2011: Speed up hash calculation

fs11_01
Alinson S. Xavier 2 months ago
parent 1296182744
commit 37f3abee42

@ -253,12 +253,13 @@ function compute_gmi(data::ProblemData, tableau::Tableau)::ConstraintSet
end end
@timeit "Pre-allocation" begin @timeit "Pre-allocation" begin
ub::Vector{Float64} = fill(Inf, nrows) cut_ub::Vector{Float64} = fill(Inf, nrows)
lb::Vector{Float64} = fill(0.999, nrows) cut_lb::Vector{Float64} = fill(0.999, nrows)
nnz_tableau::Int = length(tableau_I) nnz_tableau::Int = length(tableau_I)
lhs_I::Vector{Int} = Vector{Int}(undef, nnz_tableau) cut_lhs_I::Vector{Int} = Vector{Int}(undef, nnz_tableau)
lhs_J::Vector{Int} = Vector{Int}(undef, nnz_tableau) cut_lhs_J::Vector{Int} = Vector{Int}(undef, nnz_tableau)
lhs_V::Vector{Float64} = Vector{Float64}(undef, nnz_tableau) cut_lhs_V::Vector{Float64} = Vector{Float64}(undef, nnz_tableau)
cut_hash::Vector{UInt64} = zeros(UInt64, nrows)
nnz_count::Int = 0 nnz_count::Int = 0
end end
@ -290,22 +291,24 @@ function compute_gmi(data::ProblemData, tableau::Tableau)::ConstraintSet
# Store if significant # Store if significant
if abs(v) > 1e-8 if abs(v) > 1e-8
nnz_count += 1 nnz_count += 1
lhs_I[nnz_count] = i cut_lhs_I[nnz_count] = i
lhs_J[nnz_count] = j cut_lhs_J[nnz_count] = j
lhs_V[nnz_count] = v cut_lhs_V[nnz_count] = v
cut_hash[i] = hash(j, cut_hash[i])
cut_hash[i] = hash(v, cut_hash[i])
end end
end end
end end
@timeit "Resize arrays to actual size" begin @timeit "Resize arrays to actual size" begin
resize!(lhs_I, nnz_count) resize!(cut_lhs_I, nnz_count)
resize!(lhs_J, nnz_count) resize!(cut_lhs_J, nnz_count)
resize!(lhs_V, nnz_count) resize!(cut_lhs_V, nnz_count)
end end
@timeit "Convert to ConstraintSet" begin @timeit "Convert to ConstraintSet" begin
lhs::SparseMatrixCSC = sparse(lhs_I, lhs_J, lhs_V, nrows, ncols) cut_lhs::SparseMatrixCSC = sparse(cut_lhs_I, cut_lhs_J, cut_lhs_V, nrows, ncols)
cs::ConstraintSet = ConstraintSet(; lhs, ub, lb) cs::ConstraintSet = ConstraintSet(; lhs=cut_lhs, ub=cut_ub, lb=cut_lb, hash=cut_hash)
end end
return cs return cs

@ -486,8 +486,7 @@ function collect_gmi_FisSal2011(
@timeit "Compute cut hashses" begin @timeit "Compute cut hashses" begin
unique_indices = Int[] unique_indices = Int[]
for i in 1:ncuts for i in 1:ncuts
cut_data = (cuts_s.lhs[i, :], cuts_s.lb[i], cuts_s.ub[i]) cut_hash = cuts_s.hash[i]
cut_hash = hash(cut_data)
if !(cut_hash in pool_cut_hashes) if !(cut_hash in pool_cut_hashes)
push!(pool_cut_hashes, cut_hash) push!(pool_cut_hashes, cut_hash)
push!(unique_indices, i) push!(unique_indices, i)
@ -499,7 +498,8 @@ function collect_gmi_FisSal2011(
pool = ConstraintSet( pool = ConstraintSet(
lhs = sparse(cuts_s.lhs[unique_indices, :]'), lhs = sparse(cuts_s.lhs[unique_indices, :]'),
lb = cuts_s.lb[unique_indices], lb = cuts_s.lb[unique_indices],
ub = cuts_s.ub[unique_indices] ub = cuts_s.ub[unique_indices],
hash = cuts_s.hash[unique_indices],
) )
ncuts_unique = length(unique_indices) ncuts_unique = length(unique_indices)
multipliers_curr = zeros(ncuts_unique) multipliers_curr = zeros(ncuts_unique)
@ -511,6 +511,7 @@ function collect_gmi_FisSal2011(
pool.lhs = [pool.lhs sparse(cuts_s.lhs[unique_indices, :]')] pool.lhs = [pool.lhs sparse(cuts_s.lhs[unique_indices, :]')]
pool.lb = [pool.lb; cuts_s.lb[unique_indices]] pool.lb = [pool.lb; cuts_s.lb[unique_indices]]
pool.ub = [pool.ub; cuts_s.ub[unique_indices]] pool.ub = [pool.ub; cuts_s.ub[unique_indices]]
pool.hash = [pool.hash; cuts_s.hash[unique_indices]]
multipliers_curr = [multipliers_curr; zeros(ncuts_unique)] multipliers_curr = [multipliers_curr; zeros(ncuts_unique)]
multipliers_best = [multipliers_best; zeros(ncuts_unique)] multipliers_best = [multipliers_best; zeros(ncuts_unique)]
pool_cut_age = [pool_cut_age; zeros(ncuts_unique)] pool_cut_age = [pool_cut_age; zeros(ncuts_unique)]
@ -531,14 +532,15 @@ function collect_gmi_FisSal2011(
end end
@timeit "Update cut hashes" begin @timeit "Update cut hashes" begin
for idx in idx_remove for idx in idx_remove
cut_data = (pool.lhs[:, idx], pool.lb[idx], pool.ub[idx]) cut_hash = pool.hash[idx]
delete!(pool_cut_hashes, hash(cut_data)) delete!(pool_cut_hashes, cut_hash)
end end
end end
@timeit "Update cut pool" begin @timeit "Update cut pool" begin
pool.lhs = pool.lhs[:, idx_keep] pool.lhs = pool.lhs[:, idx_keep]
pool.ub = pool.ub[idx_keep] pool.ub = pool.ub[idx_keep]
pool.lb = pool.lb[idx_keep] pool.lb = pool.lb[idx_keep]
pool.hash = pool.hash[idx_keep]
multipliers_curr = multipliers_curr[idx_keep] multipliers_curr = multipliers_curr[idx_keep]
multipliers_best = multipliers_best[idx_keep] multipliers_best = multipliers_best[idx_keep]
pool_cut_age = pool_cut_age[idx_keep] pool_cut_age = pool_cut_age[idx_keep]

@ -35,6 +35,7 @@ Base.@kwdef mutable struct ConstraintSet
lhs::SparseMatrixCSC lhs::SparseMatrixCSC
ub::Vector{Float64} ub::Vector{Float64}
lb::Vector{Float64} lb::Vector{Float64}
hash::Union{Nothing,Vector{UInt64}} = nothing
end end
export ProblemData, Tableau, Basis, ConstraintSet export ProblemData, Tableau, Basis, ConstraintSet

Loading…
Cancel
Save