mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 00:18:51 -06:00
gmi_dual: Small fixes
This commit is contained in:
@@ -29,7 +29,7 @@ function collect_gmi_dual(
|
|||||||
|
|
||||||
@timeit "Read H5" begin
|
@timeit "Read H5" begin
|
||||||
h5_filename = replace(mps_filename, ".mps.gz" => ".h5")
|
h5_filename = replace(mps_filename, ".mps.gz" => ".h5")
|
||||||
h5 = H5File(h5_filename)
|
h5 = H5File(h5_filename, "r")
|
||||||
sol_opt_dict = Dict(
|
sol_opt_dict = Dict(
|
||||||
zip(
|
zip(
|
||||||
h5.get_array("static_var_names"),
|
h5.get_array("static_var_names"),
|
||||||
@@ -255,7 +255,7 @@ end
|
|||||||
|
|
||||||
function ExpertDualGmiComponent_before_mip(test_h5, model, _)
|
function ExpertDualGmiComponent_before_mip(test_h5, model, _)
|
||||||
# Read cuts and optimal solution
|
# Read cuts and optimal solution
|
||||||
h5 = H5File(test_h5)
|
h5 = H5File(test_h5, "r")
|
||||||
sol_opt_dict = Dict(
|
sol_opt_dict = Dict(
|
||||||
zip(
|
zip(
|
||||||
h5.get_array("static_var_names"),
|
h5.get_array("static_var_names"),
|
||||||
@@ -450,70 +450,69 @@ function _dualgmi_generate(train_h5, model)
|
|||||||
end
|
end
|
||||||
|
|
||||||
@timeit "Collect cuts from H5 files" begin
|
@timeit "Collect cuts from H5 files" begin
|
||||||
cut_basis_vars = nothing
|
vars_to_unique_basis_offset = Dict()
|
||||||
cut_basis_sizes = nothing
|
unique_basis_vars = nothing
|
||||||
cut_rows = nothing
|
unique_basis_sizes = nothing
|
||||||
|
unique_basis_rows = nothing
|
||||||
|
|
||||||
for h5_filename in train_h5
|
for h5_filename in train_h5
|
||||||
h5 = H5File(h5_filename)
|
h5 = H5File(h5_filename, "r")
|
||||||
cut_basis_vars_sample = h5.get_array("cuts_basis_vars")
|
cut_basis_vars = h5.get_array("cuts_basis_vars")
|
||||||
cut_basis_sizes_sample = h5.get_array("cuts_basis_sizes")
|
cut_basis_sizes = h5.get_array("cuts_basis_sizes")
|
||||||
cut_rows_sample = h5.get_array("cuts_rows")
|
cut_rows = h5.get_array("cuts_rows")
|
||||||
if cut_basis_vars === nothing
|
ncuts, nvars = size(cut_basis_vars)
|
||||||
cut_basis_vars = cut_basis_vars_sample
|
if unique_basis_vars === nothing
|
||||||
cut_basis_sizes = cut_basis_sizes_sample
|
unique_basis_vars = Matrix{Int}(undef, 0, nvars)
|
||||||
cut_rows = cut_rows_sample
|
unique_basis_sizes = Matrix{Int}(undef, 0, 4)
|
||||||
else
|
unique_basis_rows = Dict{Int,Set{Int}}()
|
||||||
cut_basis_vars = [cut_basis_vars; cut_basis_vars_sample]
|
end
|
||||||
cut_basis_sizes = [cut_basis_sizes; cut_basis_sizes_sample]
|
for i in 1:ncuts
|
||||||
cut_rows = [cut_rows; cut_rows_sample]
|
vars = cut_basis_vars[i, :]
|
||||||
|
sizes = cut_basis_sizes[i, :]
|
||||||
|
row = cut_rows[i]
|
||||||
|
if vars ∉ keys(vars_to_unique_basis_offset)
|
||||||
|
offset = size(unique_basis_vars)[1] + 1
|
||||||
|
vars_to_unique_basis_offset[vars] = offset
|
||||||
|
unique_basis_vars = [unique_basis_vars; vars']
|
||||||
|
unique_basis_sizes = [unique_basis_sizes; sizes']
|
||||||
|
unique_basis_rows[offset] = Set()
|
||||||
|
end
|
||||||
|
offset = vars_to_unique_basis_offset[vars]
|
||||||
|
push!(unique_basis_rows[offset], row)
|
||||||
end
|
end
|
||||||
h5.close()
|
h5.close()
|
||||||
end
|
end
|
||||||
ncuts, nvars = size(cut_basis_vars)
|
|
||||||
end
|
|
||||||
|
|
||||||
@timeit "Group cuts by tableau basis" begin
|
|
||||||
vars_to_unique_basis_offset = Dict()
|
|
||||||
unique_basis_vars = Matrix{Int}(undef, 0, nvars)
|
|
||||||
unique_basis_sizes = Matrix{Int}(undef, 0, 4)
|
|
||||||
unique_basis_rows = Dict{Int,Set{Int}}()
|
|
||||||
for i in 1:ncuts
|
|
||||||
vars = cut_basis_vars[i, :]
|
|
||||||
sizes = cut_basis_sizes[i, :]
|
|
||||||
row = cut_rows[i]
|
|
||||||
if vars ∉ keys(vars_to_unique_basis_offset)
|
|
||||||
offset = size(unique_basis_vars)[1] + 1
|
|
||||||
vars_to_unique_basis_offset[vars] = offset
|
|
||||||
unique_basis_vars = [unique_basis_vars; vars']
|
|
||||||
unique_basis_sizes = [unique_basis_sizes; sizes']
|
|
||||||
unique_basis_rows[offset] = Set()
|
|
||||||
end
|
|
||||||
offset = vars_to_unique_basis_offset[vars]
|
|
||||||
push!(unique_basis_rows[offset], row)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
@timeit "Compute tableaus and cuts" begin
|
@timeit "Compute tableaus and cuts" begin
|
||||||
all_cuts = nothing
|
all_cuts = nothing
|
||||||
for (offset, rows) in unique_basis_rows
|
for (offset, rows) in unique_basis_rows
|
||||||
vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :]
|
try
|
||||||
current_basis = Basis(;
|
vbb, vnn, cbb, cnn = unique_basis_sizes[offset, :]
|
||||||
var_basic = unique_basis_vars[offset, 1:vbb],
|
current_basis = Basis(;
|
||||||
var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn],
|
var_basic = unique_basis_vars[offset, 1:vbb],
|
||||||
constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb],
|
var_nonbasic = unique_basis_vars[offset, vbb+1:vbb+vnn],
|
||||||
constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn],
|
constr_basic = unique_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb],
|
||||||
)
|
constr_nonbasic = unique_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn],
|
||||||
|
)
|
||||||
|
|
||||||
tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
|
tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
|
||||||
cuts_s = compute_gmi(data_s, tableau)
|
cuts_s = compute_gmi(data_s, tableau)
|
||||||
cuts = backwards(transforms, cuts_s)
|
cuts = backwards(transforms, cuts_s)
|
||||||
|
if all_cuts === nothing
|
||||||
if all_cuts === nothing
|
all_cuts = cuts
|
||||||
all_cuts = cuts
|
else
|
||||||
else
|
all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
|
||||||
all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
|
all_cuts.lb = [all_cuts.lb; cuts.lb]
|
||||||
all_cuts.lb = [all_cuts.lb; cuts.lb]
|
all_cuts.ub = [all_cuts.ub; cuts.ub]
|
||||||
all_cuts.ub = [all_cuts.ub; cuts.ub]
|
end
|
||||||
|
catch e
|
||||||
|
if isa(e, AssertionError)
|
||||||
|
@warn "Numerical error detected. Skipping cuts from current tableau."
|
||||||
|
continue
|
||||||
|
else
|
||||||
|
rethrow(e)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@@ -555,13 +554,14 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
|
|||||||
neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true)
|
neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true)
|
||||||
neigh_ind = neigh_ind .+ 1
|
neigh_ind = neigh_ind .+ 1
|
||||||
N = length(neigh_dist)
|
N = length(neigh_dist)
|
||||||
|
k = min(N, data.k)
|
||||||
|
|
||||||
if data.strategy == "near"
|
if data.strategy == "near"
|
||||||
selected = collect(1:(data.k))
|
selected = collect(1:k)
|
||||||
elseif data.strategy == "far"
|
elseif data.strategy == "far"
|
||||||
selected = collect((N - data.k + 1) : N)
|
selected = collect((N - k + 1) : N)
|
||||||
elseif data.strategy == "rand"
|
elseif data.strategy == "rand"
|
||||||
selected = shuffle(collect(1:N))[1:(data.k)]
|
selected = shuffle(collect(1:N))[1:k]
|
||||||
else
|
else
|
||||||
error("unknown strategy: $(data.strategy)")
|
error("unknown strategy: $(data.strategy)")
|
||||||
end
|
end
|
||||||
@@ -569,7 +569,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
|
|||||||
@info "Dual GMI: Selected neighbors ($(data.strategy)):"
|
@info "Dual GMI: Selected neighbors ($(data.strategy)):"
|
||||||
neigh_dist = neigh_dist[selected]
|
neigh_dist = neigh_dist[selected]
|
||||||
neigh_ind = neigh_ind[selected]
|
neigh_ind = neigh_ind[selected]
|
||||||
for i in 1:data.k
|
for i in 1:k
|
||||||
h5_filename = data.train_h5[neigh_ind[i]]
|
h5_filename = data.train_h5[neigh_ind[i]]
|
||||||
dist = neigh_dist[i]
|
dist = neigh_dist[i]
|
||||||
@info " $(h5_filename) dist=$(dist)"
|
@info " $(h5_filename) dist=$(dist)"
|
||||||
@@ -591,7 +591,7 @@ function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _
|
|||||||
print_timer()
|
print_timer()
|
||||||
|
|
||||||
stats = Dict()
|
stats = Dict()
|
||||||
stats["KnnDualGmi: k"] = data.k
|
stats["KnnDualGmi: k"] = k
|
||||||
stats["KnnDualGmi: strategy"] = data.strategy
|
stats["KnnDualGmi: strategy"] = data.strategy
|
||||||
stats["KnnDualGmi: cuts"] = length(cuts.lb)
|
stats["KnnDualGmi: cuts"] = length(cuts.lb)
|
||||||
stats["KnnDualGmi: time generate"] = time_generate
|
stats["KnnDualGmi: time generate"] = time_generate
|
||||||
|
|||||||
Reference in New Issue
Block a user