gmi_dual: Small fixes

This commit is contained in:
2024-10-17 09:36:46 -05:00
parent 006ace00e7
commit 011a106d20

View File

@@ -29,7 +29,7 @@ function collect_gmi_dual(
@timeit "Read H5" begin @timeit "Read H5" begin
h5_filename = replace(mps_filename, ".mps.gz" => ".h5") h5_filename = replace(mps_filename, ".mps.gz" => ".h5")
h5 = H5File(h5_filename) h5 = H5File(h5_filename, "r")
sol_opt_dict = Dict( sol_opt_dict = Dict(
zip( zip(
h5.get_array("static_var_names"), h5.get_array("static_var_names"),
@@ -255,7 +255,7 @@ end
function ExpertDualGmiComponent_before_mip(test_h5, model, _) function ExpertDualGmiComponent_before_mip(test_h5, model, _)
# Read cuts and optimal solution # Read cuts and optimal solution
h5 = H5File(test_h5) h5 = H5File(test_h5, "r")
sol_opt_dict = Dict( sol_opt_dict = Dict(
zip( zip(
h5.get_array("static_var_names"), h5.get_array("static_var_names"),
@@ -450,70 +450,69 @@ function _dualgmi_generate(train_h5, model)
end end
@timeit "Collect cuts from H5 files" begin @timeit "Collect cuts from H5 files" begin
cut_basis_vars = nothing vars_to_unique_basis_offset = Dict()
cut_basis_sizes = nothing unique_basis_vars = nothing
cut_rows = nothing unique_basis_sizes = nothing
unique_basis_rows = nothing
for h5_filename in train_h5 for h5_filename in train_h5
h5 = H5File(h5_filename) h5 = H5File(h5_filename, "r")
cut_basis_vars_sample = h5.get_array("cuts_basis_vars") cut_basis_vars = h5.get_array("cuts_basis_vars")
cut_basis_sizes_sample = h5.get_array("cuts_basis_sizes") cut_basis_sizes = h5.get_array("cuts_basis_sizes")
cut_rows_sample = h5.get_array("cuts_rows") cut_rows = h5.get_array("cuts_rows")
if cut_basis_vars === nothing ncuts, nvars = size(cut_basis_vars)
cut_basis_vars = cut_basis_vars_sample if unique_basis_vars === nothing
cut_basis_sizes = cut_basis_sizes_sample unique_basis_vars = Matrix{Int}(undef, 0, nvars)
cut_rows = cut_rows_sample unique_basis_sizes = Matrix{Int}(undef, 0, 4)
else unique_basis_rows = Dict{Int,Set{Int}}()
cut_basis_vars = [cut_basis_vars; cut_basis_vars_sample] end
cut_basis_sizes = [cut_basis_sizes; cut_basis_sizes_sample] for i in 1:ncuts
cut_rows = [cut_rows; cut_rows_sample] vars = cut_basis_vars[i, :]
sizes = cut_basis_sizes[i, :]
row = cut_rows[i]
if vars ∉ keys(vars_to_unique_basis_offset)
offset = size(unique_basis_vars)[1] + 1
vars_to_unique_basis_offset[vars] = offset
unique_basis_vars = [unique_basis_vars; vars']
unique_basis_sizes = [unique_basis_sizes; sizes']
unique_basis_rows[offset] = Set()
end
offset = vars_to_unique_basis_offset[vars]
push!(unique_basis_rows[offset], row)
end end
h5.close() h5.close()
end end
ncuts, nvars = size(cut_basis_vars)
end
@timeit "Group cuts by tableau basis" begin
vars_to_unique_basis_offset = Dict()
unique_basis_vars = Matrix{Int}(undef, 0, nvars)
unique_basis_sizes = Matrix{Int}(undef, 0, 4)
unique_basis_rows = Dict{Int,Set{Int}}()
for i in 1:ncuts
vars = cut_basis_vars[i, :]
sizes = cut_basis_sizes[i, :]
row = cut_rows[i]
if vars ∉ keys(vars_to_unique_basis_offset)
offset = size(unique_basis_vars)[1] + 1
vars_to_unique_basis_offset[vars] = offset
unique_basis_vars = [unique_basis_vars; vars']
unique_basis_sizes = [unique_basis_sizes; sizes']
unique_basis_rows[offset] = Set()
end
offset = vars_to_unique_basis_offset[vars]
push!(unique_basis_rows[offset], row)
end
end end
@timeit "Compute tableaus and cuts" begin @timeit "Compute tableaus and cuts" begin
all_cuts = nothing all_cuts = nothing
for (offset, rows) in unique_basis_rows for (offset, rows) in unique_basis_rows
vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :] try
current_basis = Basis(; vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :]
var_basic = unique_basis_vars[offset, 1:vbb], current_basis = Basis(;
var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn], var_basic = unique_basis_vars[offset, 1:vbb],
constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb], var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn],
constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn], constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb],
) constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn],
)
tableau = compute_tableau(data_s, current_basis; rows=collect(rows)) tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
cuts_s = compute_gmi(data_s, tableau) cuts_s = compute_gmi(data_s, tableau)
cuts = backwards(transforms, cuts_s) cuts = backwards(transforms, cuts_s)
if all_cuts === nothing
if all_cuts === nothing all_cuts = cuts
all_cuts = cuts else
else all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
all_cuts.lhs = [all_cuts.lhs; cuts.lhs] all_cuts.lb = [all_cuts.lb; cuts.lb]
all_cuts.lb = [all_cuts.lb; cuts.lb] all_cuts.ub = [all_cuts.ub; cuts.ub]
all_cuts.ub = [all_cuts.ub; cuts.ub] end
catch e
if isa(e, AssertionError)
@warn "Numerical error detected. Skipping cuts from current tableau."
continue
else
rethrow(e)
end
end end
end end
end end
@@ -555,13 +554,14 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true) neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true)
neigh_ind = neigh_ind .+ 1 neigh_ind = neigh_ind .+ 1
N = length(neigh_dist) N = length(neigh_dist)
k = min(N, data.k)
if data.strategy == "near" if data.strategy == "near"
selected = collect(1:(data.k)) selected = collect(1:k)
elseif data.strategy == "far" elseif data.strategy == "far"
selected = collect((N - data.k + 1) : N) selected = collect((N - k + 1) : N)
elseif data.strategy == "rand" elseif data.strategy == "rand"
selected = shuffle(collect(1:N))[1:(data.k)] selected = shuffle(collect(1:N))[1:k]
else else
error("unknown strategy: $(data.strategy)") error("unknown strategy: $(data.strategy)")
end end
@@ -569,7 +569,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
@info "Dual GMI: Selected neighbors ($(data.strategy)):" @info "Dual GMI: Selected neighbors ($(data.strategy)):"
neigh_dist = neigh_dist[selected] neigh_dist = neigh_dist[selected]
neigh_ind = neigh_ind[selected] neigh_ind = neigh_ind[selected]
for i in 1:data.k for i in 1:k
h5_filename = data.train_h5[neigh_ind[i]] h5_filename = data.train_h5[neigh_ind[i]]
dist = neigh_dist[i] dist = neigh_dist[i]
@info " $(h5_filename) dist=$(dist)" @info " $(h5_filename) dist=$(dist)"
@@ -591,7 +591,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
print_timer() print_timer()
stats = Dict() stats = Dict()
stats["KnnDualGmi: k"] = data.k stats["KnnDualGmi: k"] = k
stats["KnnDualGmi: strategy"] = data.strategy stats["KnnDualGmi: strategy"] = data.strategy
stats["KnnDualGmi: cuts"] = length(cuts.lb) stats["KnnDualGmi: cuts"] = length(cuts.lb)
stats["KnnDualGmi: time generate"] = time_generate stats["KnnDualGmi: time generate"] = time_generate