mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 08:28:52 -06:00
DualGMI: multiple fixes
This commit is contained in:
@@ -264,6 +264,10 @@ function collect_gmi_dual(
|
||||
)
|
||||
end
|
||||
|
||||
# TODO:
|
||||
# blp-ic98
|
||||
# neos-3627168-kasai
|
||||
|
||||
function collect_gmi_FisSal2011(
|
||||
mps_filename;
|
||||
interval_print_sec = 1,
|
||||
@@ -271,8 +275,9 @@ function collect_gmi_FisSal2011(
|
||||
max_pool_size_mb = 1024,
|
||||
optimizer,
|
||||
silent_solver = true,
|
||||
time_limit = 14_400,
|
||||
time_limit = 300,
|
||||
variant = :miplearn,
|
||||
verify_cuts = true,
|
||||
)
|
||||
variant in [:subg, :hybr, :fast, :faster, :miplearn] || error("unknown variant: $variant")
|
||||
if variant == :subg
|
||||
@@ -395,6 +400,16 @@ function collect_gmi_FisSal2011(
|
||||
for round = 1:max_rounds
|
||||
log_prefix = ' '
|
||||
log_should_print = false
|
||||
is_last_iteration = false
|
||||
if round == max_rounds
|
||||
is_last_iteration = true
|
||||
end
|
||||
|
||||
elapsed_time = time() - initial_time
|
||||
if elapsed_time > time_limit
|
||||
@info "Time limit exceeded. Stopping after current iteration."
|
||||
is_last_iteration = true
|
||||
end
|
||||
|
||||
if round > 1
|
||||
@timeit "Build Lagrangian term" begin
|
||||
@@ -430,7 +445,12 @@ function collect_gmi_FisSal2011(
|
||||
optimize!(model_s)
|
||||
basis_cache = get_basis(model_s)
|
||||
status = termination_status(model_s)
|
||||
if status != MOI.OPTIMAL
|
||||
if status == MOI.DUAL_INFEASIBLE
|
||||
@warn "LP is unbounded (dual infeasible). Resetting to best known multipliers."
|
||||
copy!(multipliers_curr, multipliers_best)
|
||||
obj_curr = obj_best
|
||||
continue
|
||||
elseif status != MOI.OPTIMAL
|
||||
error("Non-optimal termination status: $status")
|
||||
end
|
||||
sol_frac = get_x(model_s)
|
||||
@@ -602,11 +622,15 @@ function collect_gmi_FisSal2011(
|
||||
pool_size_mb = Base.summarysize(pool) / 1024^2
|
||||
while pool_size_mb >= max_pool_size_mb
|
||||
@timeit "Identify cuts to remove" begin
|
||||
σ = sortperm(pool_cut_age, rev=true)
|
||||
scores = collect(zip(multipliers_best .> 1e-6, -pool_cut_age))
|
||||
σ = sortperm(scores, rev=true)
|
||||
pool_size = length(pool.ub)
|
||||
n_keep = Int(floor(pool_size * 0.8))
|
||||
idx_keep = σ[1:n_keep]
|
||||
idx_remove = σ[(n_keep+1):end]
|
||||
|
||||
positive_multipliers_dropped = sum(multipliers_best[idx_remove] .> 1e-6)
|
||||
@info "Dropping $(length(idx_remove)) cuts ($(positive_multipliers_dropped) with multipliers_best)"
|
||||
end
|
||||
@timeit "Update cut hashes" begin
|
||||
for idx in idx_remove
|
||||
@@ -641,7 +665,7 @@ function collect_gmi_FisSal2011(
|
||||
end
|
||||
end
|
||||
|
||||
if mod(round - 1, interval_large_lp) == 0 || round == max_rounds
|
||||
if mod(round - 1, interval_large_lp) == 0 || is_last_iteration
|
||||
log_should_print = true
|
||||
@timeit "Update multipliers (large LP)" begin
|
||||
selected_idx = []
|
||||
@@ -756,6 +780,10 @@ function collect_gmi_FisSal2011(
|
||||
log_should_print = true
|
||||
end
|
||||
|
||||
if is_last_iteration
|
||||
log_should_print = true
|
||||
end
|
||||
|
||||
if log_should_print
|
||||
last_print_time = time()
|
||||
@printf(
|
||||
@@ -783,17 +811,18 @@ function collect_gmi_FisSal2011(
|
||||
end
|
||||
end
|
||||
|
||||
elapsed_time = time() - initial_time
|
||||
if elapsed_time > time_limit
|
||||
@info "Time limit exceeded. Stopping."
|
||||
if is_last_iteration
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@info "Best gap closure: $(gapcl_best)"
|
||||
|
||||
@timeit "Keep only active cuts" begin
|
||||
positive_idx = findall(multipliers_best .> 1e-6)
|
||||
if length(positive_idx) == 0 && gapcl_best > 0
|
||||
error("gap closure with zero cuts")
|
||||
end
|
||||
|
||||
@timeit "Clean up cut pool" begin
|
||||
pool.lhs = pool.lhs[:, positive_idx]
|
||||
@@ -821,41 +850,94 @@ function collect_gmi_FisSal2011(
|
||||
@info "Keeping $(length(positive_idx)) cuts from $(length(used_basis_ids)) unique bases"
|
||||
end
|
||||
|
||||
@timeit "Write cuts to H5" begin
|
||||
if !isempty(cut_basis_id)
|
||||
@timeit "Convert IDs to offsets" begin
|
||||
id_to_offset = Dict{Int, Int}()
|
||||
gmi_basis_vars = []
|
||||
gmi_basis_sizes = []
|
||||
for (offset, basis_id) in enumerate(sort(collect(keys(basis_id_to_vars))))
|
||||
id_to_offset[basis_id] = offset
|
||||
push!(gmi_basis_vars, basis_id_to_vars[basis_id])
|
||||
push!(gmi_basis_sizes, basis_id_to_sizes[basis_id])
|
||||
end
|
||||
gmi_cut_basis = [id_to_offset[basis_id] for basis_id in cut_basis_id]
|
||||
gmi_cut_row = cut_row
|
||||
end
|
||||
|
||||
@timeit "Convert to matrices" begin
|
||||
gmi_basis_vars_matrix = hcat(gmi_basis_vars...)'
|
||||
gmi_basis_sizes_matrix = hcat(gmi_basis_sizes...)'
|
||||
end
|
||||
|
||||
@timeit "Write H5" begin
|
||||
h5 = H5File(h5_filename, "r+")
|
||||
h5.put_array("gmi_basis_vars", gmi_basis_vars_matrix)
|
||||
h5.put_array("gmi_basis_sizes", gmi_basis_sizes_matrix)
|
||||
h5.put_array("gmi_cut_basis", gmi_cut_basis)
|
||||
h5.put_array("gmi_cut_row", gmi_cut_row)
|
||||
h5.file.close()
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
to = TimerOutputs.get_defaulttimer()
|
||||
stats_time = TimerOutputs.tottime(to) / 1e9
|
||||
print_timer()
|
||||
|
||||
if length(positive_idx) > 0
|
||||
@timeit "Write cuts to H5" begin
|
||||
if !isempty(cut_basis_id)
|
||||
@timeit "Convert IDs to offsets" begin
|
||||
id_to_offset = Dict{Int, Int}()
|
||||
gmi_basis_vars = []
|
||||
gmi_basis_sizes = []
|
||||
for (offset, basis_id) in enumerate(sort(collect(keys(basis_id_to_vars))))
|
||||
id_to_offset[basis_id] = offset
|
||||
push!(gmi_basis_vars, basis_id_to_vars[basis_id])
|
||||
push!(gmi_basis_sizes, basis_id_to_sizes[basis_id])
|
||||
end
|
||||
gmi_cut_basis = [id_to_offset[basis_id] for basis_id in cut_basis_id]
|
||||
gmi_cut_row = cut_row
|
||||
end
|
||||
|
||||
@timeit "Convert to matrices" begin
|
||||
gmi_basis_vars_matrix = hcat(gmi_basis_vars...)'
|
||||
gmi_basis_sizes_matrix = hcat(gmi_basis_sizes...)'
|
||||
end
|
||||
|
||||
@timeit "Write H5" begin
|
||||
h5 = H5File(h5_filename, "r+")
|
||||
h5.put_array("gmi_basis_vars", gmi_basis_vars_matrix)
|
||||
h5.put_array("gmi_basis_sizes", gmi_basis_sizes_matrix)
|
||||
h5.put_array("gmi_cut_basis", gmi_cut_basis)
|
||||
h5.put_array("gmi_cut_row", gmi_cut_row)
|
||||
h5.file.close()
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if verify_cuts
|
||||
@timeit "Verify cuts in current model" begin
|
||||
@info "Verifying cuts in current standard form model using pool..."
|
||||
if !isempty(cut_basis_id)
|
||||
@info "Adding $(length(pool.lb)) cuts from pool to current model"
|
||||
pool.lhs = sparse(pool.lhs')
|
||||
constrs = build_constraints(model_s, pool)
|
||||
add_constraint.(model_s, constrs)
|
||||
set_objective_function(model_s, orig_obj_s)
|
||||
optimize!(model_s)
|
||||
status = termination_status(model_s)
|
||||
if status != MOI.OPTIMAL
|
||||
error("Non-optimal termination status: $status")
|
||||
end
|
||||
obj_verify_s = objective_value(model_s)
|
||||
gapcl_verify_s = gapcl(obj_verify_s)
|
||||
@show gapcl_verify_s
|
||||
@show gapcl_best
|
||||
if abs(gapcl_best - gapcl_verify_s) > 0.01
|
||||
error("Gap closures differ: $(gapcl_best) ≠ $(gapcl_verify_s)")
|
||||
end
|
||||
@info "Current model gap closure matches: $(gapcl_best) ≈ $(gapcl_verify_s)"
|
||||
else
|
||||
@warn "No cuts in pool to verify"
|
||||
end
|
||||
end
|
||||
|
||||
@timeit "Verify stored cuts" begin
|
||||
@info "Verifying stored cuts..."
|
||||
model_verify = read_from_file(mps_filename)
|
||||
set_optimizer(model_verify, optimizer)
|
||||
verification_cuts = _dualgmi_generate([h5_filename], model_verify; test_h5=h5_filename)
|
||||
constrs = build_constraints(model_verify, verification_cuts)
|
||||
add_constraint.(model_verify, constrs)
|
||||
relax_integrality(model_verify)
|
||||
optimize!(model_verify)
|
||||
status = termination_status(model_verify)
|
||||
if status != MOI.OPTIMAL
|
||||
error("Non-optimal termination status: $status")
|
||||
end
|
||||
obj_verify = objective_value(model_verify)
|
||||
gapcl_verify = gapcl(obj_verify)
|
||||
@show gapcl_verify
|
||||
@show gapcl_best
|
||||
if abs(gapcl_best - gapcl_verify) > 0.01
|
||||
error("Gap closures differ: $(gapcl_best) ≠ $(gapcl_verify)")
|
||||
end
|
||||
@info "Gap closure matches gapcl_best: $(gapcl_best) ≈ $(gapcl_verify)"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return OrderedDict(
|
||||
"gapcl_best" => gapcl_best,
|
||||
"gapcl_curr" => gapcl_curr,
|
||||
@@ -977,13 +1059,34 @@ function _dualgmi_compress_h5(h5_filename)
|
||||
h5.file.close()
|
||||
end
|
||||
|
||||
function _dualgmi_generate(train_h5, model)
|
||||
function _dualgmi_generate(train_h5, model; test_h5=nothing)
|
||||
@timeit "Read problem data" begin
|
||||
data = ProblemData(model)
|
||||
end
|
||||
@timeit "Convert to standard form" begin
|
||||
data_s, transforms = convert_to_standard_form(data)
|
||||
end
|
||||
@timeit "Read optimal solution from test H5" begin
|
||||
sol_opt_dict = nothing
|
||||
sol_opt = nothing
|
||||
sol_opt_s = nothing
|
||||
if test_h5 !== nothing
|
||||
try
|
||||
h5 = H5File(test_h5, "r")
|
||||
var_names = h5.get_array("static_var_names")
|
||||
var_values = h5.get_array("mip_var_values")
|
||||
h5.close()
|
||||
if var_names !== nothing && var_values !== nothing
|
||||
sol_opt_dict = Dict(zip(var_names, convert(Array{Float64}, var_values)))
|
||||
sol_opt = [sol_opt_dict[n] for n in data.var_names]
|
||||
sol_opt_s = forward(transforms, sol_opt)
|
||||
@info "Loaded optimal solution for cut validation"
|
||||
end
|
||||
catch e
|
||||
@warn "Could not read optimal solution from test H5 file: $e"
|
||||
end
|
||||
end
|
||||
end
|
||||
@timeit "Collect cuts from H5 files" begin
|
||||
basis_vars_to_basis_offset = Dict()
|
||||
combined_basis_sizes = nothing
|
||||
@@ -1052,6 +1155,10 @@ function _dualgmi_generate(train_h5, model)
|
||||
tableau = compute_tableau(data_s, current_basis; rows=collect(rows))
|
||||
cuts_s = compute_gmi(data_s, tableau)
|
||||
cuts = backwards(transforms, cuts_s)
|
||||
if sol_opt_s !== nothing && sol_opt !== nothing
|
||||
assert_does_not_cut_off(cuts_s, sol_opt_s)
|
||||
assert_does_not_cut_off(cuts, sol_opt)
|
||||
end
|
||||
if all_cuts === nothing
|
||||
all_cuts = cuts
|
||||
else
|
||||
|
||||
@@ -30,7 +30,7 @@ function assert_cuts_off(cuts::ConstraintSet, x::Vector{Float64}, tol = 1e-6)
|
||||
vals = cuts.lhs * x
|
||||
for i = 1:length(cuts.lb)
|
||||
if (vals[i] <= cuts.ub[i] - tol) && (vals[i] >= cuts.lb[i] + tol)
|
||||
throw(ErrorException("inequality fails to cut off fractional solution"))
|
||||
throw(ErrorException("inequality $i fails to cut off fractional solution: $(cuts.lb[i]) <= $(vals[i]) <= $(cuts.ub[i])"))
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -39,7 +39,7 @@ function assert_does_not_cut_off(cuts::ConstraintSet, x::Vector{Float64}; tol =
|
||||
vals = cuts.lhs * x
|
||||
for i = 1:length(cuts.lb)
|
||||
if (vals[i] >= cuts.ub[i]) || (vals[i] <= cuts.lb[i])
|
||||
throw(ErrorException("inequality $i cuts off integer solution"))
|
||||
throw(ErrorException("inequality $i cuts off integer solution: $(cuts.lb[i]) <= $(vals[i]) <= $(cuts.ub[i])"))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -173,6 +173,10 @@ function compute_tableau(
|
||||
tol = 1e-8,
|
||||
estimated_density = 0.10,
|
||||
)::Tableau
|
||||
if isnan(estimated_density) || estimated_density <= 0
|
||||
estimated_density = 0.10
|
||||
end
|
||||
|
||||
@timeit "Split data" begin
|
||||
nrows, ncols = size(data.constr_lhs)
|
||||
lhs_slacks = sparse(I, nrows, nrows)
|
||||
|
||||
@@ -231,10 +231,10 @@ function backwards!(t::SplitFreeVars, c::ConstraintSet)
|
||||
# Assert only LE constraints are left (EQ constraints are not supported)
|
||||
@assert all(c.lb .== -Inf)
|
||||
|
||||
# Take minimum (weakest) coefficient
|
||||
# Combine split free variable coefficients: x = x_p - x_m
|
||||
B, F = t.B, t.F
|
||||
for i = 1:F
|
||||
c.lhs[:, B+i] = min.(c.lhs[:, B+i], -c.lhs[:, B+F+i])
|
||||
c.lhs[:, B+i] = c.lhs[:, B+i] - c.lhs[:, B+F+i]
|
||||
end
|
||||
c.lhs = c.lhs[:, 1:(B+F)]
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user