DualGMI: Fix type errors

dev
Alinson S. Xavier 2 months ago
parent c3a8fa6a08
commit 6c903d0b19

@ -322,6 +322,9 @@ function _dualgmi_compress_h5(h5_filename)
orig_cut_basis_vars = h5.get_array("cuts_basis_vars")
orig_cut_basis_sizes = h5.get_array("cuts_basis_sizes")
orig_cut_rows = h5.get_array("cuts_rows")
if orig_cut_basis_vars === nothing
return
end
ncuts, _ = size(orig_cut_basis_vars)
h5.close()
@ -360,9 +363,11 @@ function _dualgmi_generate(train_h5, model)
end
@timeit "Collect cuts from H5 files" begin
basis_vars_to_basis_offset = Dict()
combined_basis_vars = []
combined_basis_sizes = []
combined_cut_rows = []
combined_basis_sizes = nothing
combined_basis_sizes_list = Any[]
combined_basis_vars = nothing
combined_basis_vars_list = Any[]
combined_cut_rows = Any[]
for h5_filename in train_h5
@timeit "get_array (new)" begin
h5 = H5File(h5_filename, "r")
@ -379,10 +384,10 @@ function _dualgmi_generate(train_h5, model)
vars = gmi_basis_vars[local_offset, :]
sizes = gmi_basis_sizes[local_offset, :]
if vars keys(basis_vars_to_basis_offset)
combined_offset = length(combined_basis_vars) + 1
combined_offset = length(combined_basis_vars_list) + 1
basis_vars_to_basis_offset[vars] = combined_offset
push!(combined_basis_vars, vars)
push!(combined_basis_sizes, sizes)
push!(combined_basis_vars_list, vars)
push!(combined_basis_sizes_list, sizes)
push!(combined_cut_rows, Set{Int}())
end
combined_offset = basis_vars_to_basis_offset[vars]
@ -399,8 +404,8 @@ function _dualgmi_generate(train_h5, model)
end
end
@timeit "convert lists to matrices" begin
combined_basis_vars = hcat(combined_basis_vars...)'
combined_basis_sizes = hcat(combined_basis_sizes...)'
combined_basis_vars = hcat(combined_basis_vars_list...)'
combined_basis_sizes = hcat(combined_basis_sizes_list...)'
end
end
end

Loading…
Cancel
Save