mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 08:28:52 -06:00
DualGMI: Fix type errors
This commit is contained in:
@@ -322,6 +322,9 @@ function _dualgmi_compress_h5(h5_filename)
|
||||
orig_cut_basis_vars = h5.get_array("cuts_basis_vars")
|
||||
orig_cut_basis_sizes = h5.get_array("cuts_basis_sizes")
|
||||
orig_cut_rows = h5.get_array("cuts_rows")
|
||||
if orig_cut_basis_vars === nothing
|
||||
return
|
||||
end
|
||||
ncuts, _ = size(orig_cut_basis_vars)
|
||||
h5.close()
|
||||
|
||||
@@ -360,9 +363,11 @@ function _dualgmi_generate(train_h5, model)
|
||||
end
|
||||
@timeit "Collect cuts from H5 files" begin
|
||||
basis_vars_to_basis_offset = Dict()
|
||||
combined_basis_vars = []
|
||||
combined_basis_sizes = []
|
||||
combined_cut_rows = []
|
||||
combined_basis_sizes = nothing
|
||||
combined_basis_sizes_list = Any[]
|
||||
combined_basis_vars = nothing
|
||||
combined_basis_vars_list = Any[]
|
||||
combined_cut_rows = Any[]
|
||||
for h5_filename in train_h5
|
||||
@timeit "get_array (new)" begin
|
||||
h5 = H5File(h5_filename, "r")
|
||||
@@ -379,10 +384,10 @@ function _dualgmi_generate(train_h5, model)
|
||||
vars = gmi_basis_vars[local_offset, :]
|
||||
sizes = gmi_basis_sizes[local_offset, :]
|
||||
if vars ∉ keys(basis_vars_to_basis_offset)
|
||||
combined_offset = length(combined_basis_vars) + 1
|
||||
combined_offset = length(combined_basis_vars_list) + 1
|
||||
basis_vars_to_basis_offset[vars] = combined_offset
|
||||
push!(combined_basis_vars, vars)
|
||||
push!(combined_basis_sizes, sizes)
|
||||
push!(combined_basis_vars_list, vars)
|
||||
push!(combined_basis_sizes_list, sizes)
|
||||
push!(combined_cut_rows, Set{Int}())
|
||||
end
|
||||
combined_offset = basis_vars_to_basis_offset[vars]
|
||||
@@ -399,8 +404,8 @@ function _dualgmi_generate(train_h5, model)
|
||||
end
|
||||
end
|
||||
@timeit "convert lists to matrices" begin
|
||||
combined_basis_vars = hcat(combined_basis_vars...)'
|
||||
combined_basis_sizes = hcat(combined_basis_sizes...)'
|
||||
combined_basis_vars = hcat(combined_basis_vars_list...)'
|
||||
combined_basis_sizes = hcat(combined_basis_sizes_list...)'
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user