gmi_dual: Small fixes

feature/replay^2
Alinson S. Xavier 11 months ago
parent 006ace00e7
commit 011a106d20

@ -29,7 +29,7 @@ function collect_gmi_dual(
@timeit "Read H5" begin @timeit "Read H5" begin
h5_filename = replace(mps_filename, ".mps.gz" => ".h5") h5_filename = replace(mps_filename, ".mps.gz" => ".h5")
h5 = H5File(h5_filename) h5 = H5File(h5_filename, "r")
sol_opt_dict = Dict( sol_opt_dict = Dict(
zip( zip(
h5.get_array("static_var_names"), h5.get_array("static_var_names"),
@ -255,7 +255,7 @@ end
function ExpertDualGmiComponent_before_mip(test_h5, model, _) function ExpertDualGmiComponent_before_mip(test_h5, model, _)
# Read cuts and optimal solution # Read cuts and optimal solution
h5 = H5File(test_h5) h5 = H5File(test_h5, "r")
sol_opt_dict = Dict( sol_opt_dict = Dict(
zip( zip(
h5.get_array("static_var_names"), h5.get_array("static_var_names"),
@ -450,70 +450,69 @@ function _dualgmi_generate(train_h5, model)
end end
@timeit "Collect cuts from H5 files" begin @timeit "Collect cuts from H5 files" begin
cut_basis_vars = nothing vars_to_unique_basis_offset = Dict()
cut_basis_sizes = nothing unique_basis_vars = nothing
cut_rows = nothing unique_basis_sizes = nothing
unique_basis_rows = nothing
for h5_filename in train_h5 for h5_filename in train_h5
h5 = H5File(h5_filename) h5 = H5File(h5_filename, "r")
cut_basis_vars_sample = h5.get_array("cuts_basis_vars") cut_basis_vars = h5.get_array("cuts_basis_vars")
cut_basis_sizes_sample = h5.get_array("cuts_basis_sizes") cut_basis_sizes = h5.get_array("cuts_basis_sizes")
cut_rows_sample = h5.get_array("cuts_rows") cut_rows = h5.get_array("cuts_rows")
if cut_basis_vars === nothing ncuts, nvars = size(cut_basis_vars)
cut_basis_vars = cut_basis_vars_sample if unique_basis_vars === nothing
cut_basis_sizes = cut_basis_sizes_sample unique_basis_vars = Matrix{Int}(undef, 0, nvars)
cut_rows = cut_rows_sample unique_basis_sizes = Matrix{Int}(undef, 0, 4)
else unique_basis_rows = Dict{Int,Set{Int}}()
cut_basis_vars = [cut_basis_vars; cut_basis_vars_sample]
cut_basis_sizes = [cut_basis_sizes; cut_basis_sizes_sample]
cut_rows = [cut_rows; cut_rows_sample]
end end
h5.close() for i in 1:ncuts
end vars = cut_basis_vars[i, :]
ncuts, nvars = size(cut_basis_vars) sizes = cut_basis_sizes[i, :]
end row = cut_rows[i]
if vars ∉ keys(vars_to_unique_basis_offset)
@timeit "Group cuts by tableau basis" begin offset = size(unique_basis_vars)[1] + 1
vars_to_unique_basis_offset = Dict() vars_to_unique_basis_offset[vars] = offset
unique_basis_vars = Matrix{Int}(undef, 0, nvars) unique_basis_vars = [unique_basis_vars; vars']
unique_basis_sizes = Matrix{Int}(undef, 0, 4) unique_basis_sizes = [unique_basis_sizes; sizes']
unique_basis_rows = Dict{Int,Set{Int}}() unique_basis_rows[offset] = Set()
for i in 1:ncuts end
vars = cut_basis_vars[i, :] offset = vars_to_unique_basis_offset[vars]
sizes = cut_basis_sizes[i, :] push!(unique_basis_rows[offset], row)
row = cut_rows[i]
if vars ∉ keys(vars_to_unique_basis_offset)
offset = size(unique_basis_vars)[1] + 1
vars_to_unique_basis_offset[vars] = offset
unique_basis_vars = [unique_basis_vars; vars']
unique_basis_sizes = [unique_basis_sizes; sizes']
unique_basis_rows[offset] = Set()
end end
offset = vars_to_unique_basis_offset[vars] h5.close()
push!(unique_basis_rows[offset], row)
end end
end end
@timeit "Compute tableaus and cuts" begin @timeit "Compute tableaus and cuts" begin
all_cuts = nothing all_cuts = nothing
for (offset, rows) in unique_basis_rows for (offset, rows) in unique_basis_rows
vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :] try
current_basis = Basis(; vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :]
var_basic = unique_basis_vars[offset, 1:vbb], current_basis = Basis(;
var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn], var_basic = unique_basis_vars[offset, 1:vbb],
constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb], var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn],
constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn], constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb],
) constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn],
)
tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
cuts_s = compute_gmi(data_s, tableau)
cuts = backwards(transforms, cuts_s)
if all_cuts === nothing tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
all_cuts = cuts cuts_s = compute_gmi(data_s, tableau)
else cuts = backwards(transforms, cuts_s)
all_cuts.lhs = [all_cuts.lhs; cuts.lhs] if all_cuts === nothing
all_cuts.lb = [all_cuts.lb; cuts.lb] all_cuts = cuts
all_cuts.ub = [all_cuts.ub; cuts.ub] else
all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
all_cuts.lb = [all_cuts.lb; cuts.lb]
all_cuts.ub = [all_cuts.ub; cuts.ub]
end
catch e
if isa(e, AssertionError)
@warn "Numerical error detected. Skipping cuts from current tableau."
continue
else
rethrow(e)
end
end end
end end
end end
@ -555,13 +554,14 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true) neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true)
neigh_ind = neigh_ind .+ 1 neigh_ind = neigh_ind .+ 1
N = length(neigh_dist) N = length(neigh_dist)
k = min(N, data.k)
if data.strategy == "near" if data.strategy == "near"
selected = collect(1:(data.k)) selected = collect(1:k)
elseif data.strategy == "far" elseif data.strategy == "far"
selected = collect((N - data.k + 1) : N) selected = collect((N - k + 1) : N)
elseif data.strategy == "rand" elseif data.strategy == "rand"
selected = shuffle(collect(1:N))[1:(data.k)] selected = shuffle(collect(1:N))[1:k]
else else
error("unknown strategy: $(data.strategy)") error("unknown strategy: $(data.strategy)")
end end
@ -569,7 +569,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
@info "Dual GMI: Selected neighbors ($(data.strategy)):" @info "Dual GMI: Selected neighbors ($(data.strategy)):"
neigh_dist = neigh_dist[selected] neigh_dist = neigh_dist[selected]
neigh_ind = neigh_ind[selected] neigh_ind = neigh_ind[selected]
for i in 1:data.k for i in 1:k
h5_filename = data.train_h5[neigh_ind[i]] h5_filename = data.train_h5[neigh_ind[i]]
dist = neigh_dist[i] dist = neigh_dist[i]
@info " $(h5_filename) dist=$(dist)" @info " $(h5_filename) dist=$(dist)"
@ -591,7 +591,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
print_timer() print_timer()
stats = Dict() stats = Dict()
stats["KnnDualGmi: k"] = data.k stats["KnnDualGmi: k"] = k
stats["KnnDualGmi: strategy"] = data.strategy stats["KnnDualGmi: strategy"] = data.strategy
stats["KnnDualGmi: cuts"] = length(cuts.lb) stats["KnnDualGmi: cuts"] = length(cuts.lb)
stats["KnnDualGmi: time generate"] = time_generate stats["KnnDualGmi: time generate"] = time_generate

Loading…
Cancel
Save