gmi_dual: Small fixes

feature/replay^2
Alinson S. Xavier 11 months ago
parent 006ace00e7
commit 011a106d20

@ -29,7 +29,7 @@ function collect_gmi_dual(
@timeit "Read H5" begin
h5_filename = replace(mps_filename, ".mps.gz" => ".h5")
h5 = H5File(h5_filename)
h5 = H5File(h5_filename, "r")
sol_opt_dict = Dict(
zip(
h5.get_array("static_var_names"),
@ -255,7 +255,7 @@ end
function ExpertDualGmiComponent_before_mip(test_h5, model, _)
# Read cuts and optimal solution
h5 = H5File(test_h5)
h5 = H5File(test_h5, "r")
sol_opt_dict = Dict(
zip(
h5.get_array("static_var_names"),
@ -450,70 +450,69 @@ function _dualgmi_generate(train_h5, model)
end
@timeit "Collect cuts from H5 files" begin
cut_basis_vars = nothing
cut_basis_sizes = nothing
cut_rows = nothing
vars_to_unique_basis_offset = Dict()
unique_basis_vars = nothing
unique_basis_sizes = nothing
unique_basis_rows = nothing
for h5_filename in train_h5
h5 = H5File(h5_filename)
cut_basis_vars_sample = h5.get_array("cuts_basis_vars")
cut_basis_sizes_sample = h5.get_array("cuts_basis_sizes")
cut_rows_sample = h5.get_array("cuts_rows")
if cut_basis_vars === nothing
cut_basis_vars = cut_basis_vars_sample
cut_basis_sizes = cut_basis_sizes_sample
cut_rows = cut_rows_sample
else
cut_basis_vars = [cut_basis_vars; cut_basis_vars_sample]
cut_basis_sizes = [cut_basis_sizes; cut_basis_sizes_sample]
cut_rows = [cut_rows; cut_rows_sample]
h5 = H5File(h5_filename, "r")
cut_basis_vars = h5.get_array("cuts_basis_vars")
cut_basis_sizes = h5.get_array("cuts_basis_sizes")
cut_rows = h5.get_array("cuts_rows")
ncuts, nvars = size(cut_basis_vars)
if unique_basis_vars === nothing
unique_basis_vars = Matrix{Int}(undef, 0, nvars)
unique_basis_sizes = Matrix{Int}(undef, 0, 4)
unique_basis_rows = Dict{Int,Set{Int}}()
end
h5.close()
end
ncuts, nvars = size(cut_basis_vars)
end
@timeit "Group cuts by tableau basis" begin
vars_to_unique_basis_offset = Dict()
unique_basis_vars = Matrix{Int}(undef, 0, nvars)
unique_basis_sizes = Matrix{Int}(undef, 0, 4)
unique_basis_rows = Dict{Int,Set{Int}}()
for i in 1:ncuts
vars = cut_basis_vars[i, :]
sizes = cut_basis_sizes[i, :]
row = cut_rows[i]
if vars ∉ keys(vars_to_unique_basis_offset)
offset = size(unique_basis_vars)[1] + 1
vars_to_unique_basis_offset[vars] = offset
unique_basis_vars = [unique_basis_vars; vars']
unique_basis_sizes = [unique_basis_sizes; sizes']
unique_basis_rows[offset] = Set()
for i in 1:ncuts
vars = cut_basis_vars[i, :]
sizes = cut_basis_sizes[i, :]
row = cut_rows[i]
if vars ∉ keys(vars_to_unique_basis_offset)
offset = size(unique_basis_vars)[1] + 1
vars_to_unique_basis_offset[vars] = offset
unique_basis_vars = [unique_basis_vars; vars']
unique_basis_sizes = [unique_basis_sizes; sizes']
unique_basis_rows[offset] = Set()
end
offset = vars_to_unique_basis_offset[vars]
push!(unique_basis_rows[offset], row)
end
offset = vars_to_unique_basis_offset[vars]
push!(unique_basis_rows[offset], row)
h5.close()
end
end
@timeit "Compute tableaus and cuts" begin
all_cuts = nothing
for (offset, rows) in unique_basis_rows
vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :]
current_basis = Basis(;
var_basic = unique_basis_vars[offset, 1:vbb],
var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn],
constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb],
constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn],
)
tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
cuts_s = compute_gmi(data_s, tableau)
cuts = backwards(transforms, cuts_s)
try
vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :]
current_basis = Basis(;
var_basic = unique_basis_vars[offset, 1:vbb],
var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn],
constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb],
constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn],
)
if all_cuts === nothing
all_cuts = cuts
else
all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
all_cuts.lb = [all_cuts.lb; cuts.lb]
all_cuts.ub = [all_cuts.ub; cuts.ub]
tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
cuts_s = compute_gmi(data_s, tableau)
cuts = backwards(transforms, cuts_s)
if all_cuts === nothing
all_cuts = cuts
else
all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
all_cuts.lb = [all_cuts.lb; cuts.lb]
all_cuts.ub = [all_cuts.ub; cuts.ub]
end
catch e
if isa(e, AssertionError)
@warn "Numerical error detected. Skipping cuts from current tableau."
continue
else
rethrow(e)
end
end
end
end
@ -555,13 +554,14 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true)
neigh_ind = neigh_ind .+ 1
N = length(neigh_dist)
k = min(N, data.k)
if data.strategy == "near"
selected = collect(1:(data.k))
selected = collect(1:k)
elseif data.strategy == "far"
selected = collect((N - data.k + 1) : N)
selected = collect((N - k + 1) : N)
elseif data.strategy == "rand"
selected = shuffle(collect(1:N))[1:(data.k)]
selected = shuffle(collect(1:N))[1:k]
else
error("unknown strategy: $(data.strategy)")
end
@ -569,7 +569,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
@info "Dual GMI: Selected neighbors ($(data.strategy)):"
neigh_dist = neigh_dist[selected]
neigh_ind = neigh_ind[selected]
for i in 1:data.k
for i in 1:k
h5_filename = data.train_h5[neigh_ind[i]]
dist = neigh_dist[i]
@info " $(h5_filename) dist=$(dist)"
@ -591,7 +591,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
print_timer()
stats = Dict()
stats["KnnDualGmi: k"] = data.k
stats["KnnDualGmi: k"] = k
stats["KnnDualGmi: strategy"] = data.strategy
stats["KnnDualGmi: cuts"] = length(cuts.lb)
stats["KnnDualGmi: time generate"] = time_generate

Loading…
Cancel
Save