# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2023, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.

using Printf
using JuMP
using HiGHS
using Random
using DataStructures
using Statistics

import ..H5FieldsExtractor

global ExpertDualGmiComponent = PyNULL()
global KnnDualGmiComponent = PyNULL()

Base.@kwdef mutable struct _KnnDualGmiData
    k = nothing
    extractor = nothing
    train_h5 = nothing
    model = nothing
    strategy = nothing
end
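# Collect GMI cuts for a single instance by iteratively tightening its LP
# relaxation: in each round, the problem is converted to standard form, tableau
# rows with fractional basic variables are selected, GMI cuts are derived and
# added back to the original model, and cuts with (numerically) zero dual value
# are dropped. Surviving cuts, together with their generating bases and tableau
# rows, are stored in the instance's companion H5 file, which must already
# contain an optimal MIP solution (mip_var_values).
#
# A minimal usage sketch (hypothetical file name; HiGHS is imported above and
# serves here as the LP solver):
#
#     stats = collect_gmi_dual(
#         "instance.mps.gz";
#         optimizer = HiGHS.Optimizer,
#     )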
function collect_gmi_dual(
    mps_filename;
    optimizer,
    max_rounds = 10,
    max_cuts_per_round = 1_000_000,
    time_limit = 3_600,
)
    reset_timer!()
    initial_time = time()

    @timeit "Read H5" begin
        h5_filename = replace(mps_filename, ".mps.gz" => ".h5")
        h5 = H5File(h5_filename, "r")
        sol_opt_dict = Dict(
            zip(
                h5.get_array("static_var_names"),
                convert(Array{Float64}, h5.get_array("mip_var_values")),
            ),
        )
        obj_mip = h5.get_scalar("mip_obj_value")
        h5.file.close()
    end

    # Define relative MIP gap
    gap(v) = 100 * abs(obj_mip - v) / abs(obj_mip)

    @timeit "Initialize" begin
        stats_obj = []
        stats_gap = []
        stats_ncuts = []
        original_basis = nothing
        all_cuts = nothing
        all_cuts_bases = nothing
        all_cuts_rows = nothing
        last_round_obj = nothing
    end

    @timeit "Read problem" begin
        model = read_from_file(mps_filename)
        set_optimizer(model, optimizer)
        obj_original = objective_function(model)
    end

    for round = 1:max_rounds
        @info "Round $(round)..."

        @timeit "Convert model to standard form" begin
            # Extract problem data
            data = ProblemData(model)

            # Construct optimal solution vector (with correct variable sequence)
            sol_opt = [sol_opt_dict[n] for n in data.var_names]

            # Assert optimal solution is feasible for the original problem
            assert_leq(data.constr_lb, data.constr_lhs * sol_opt)
            assert_leq(data.constr_lhs * sol_opt, data.constr_ub)

            # Convert to standard form
            data_s, transforms = convert_to_standard_form(data)
            model_s = to_model(data_s)
            set_optimizer(model_s, optimizer)
            relax_integrality(model_s)

            # Convert optimal solution to standard form
            sol_opt_s = forward(transforms, sol_opt)

            # Assert converted solution is feasible for standard form problem
            assert_eq(data_s.constr_lhs * sol_opt_s, data_s.constr_lb)
        end

        @timeit "Optimize standard model" begin
            @info "Optimizing standard model..."
            optimize!(model_s)
            obj = objective_value(model_s)
            if round == 1
                push!(stats_obj, obj)
                push!(stats_gap, gap(obj))
                push!(stats_ncuts, 0)
            else
                if obj ≈ last_round_obj
                    @info "No improvement in obj value. Aborting."
                    break
                end
            end
            if termination_status(model_s) != MOI.OPTIMAL
                error("Non-optimal termination status")
            end
            last_round_obj = obj
        end

        @timeit "Select tableau rows" begin
            basis = get_basis(model_s)
            if round == 1
                original_basis = basis
            end
            sol_frac = get_x(model_s)
            selected_rows =
                select_gmi_rows(data_s, basis, sol_frac, max_rows = max_cuts_per_round)
        end

        @timeit "Compute tableau rows" begin
            tableau = compute_tableau(data_s, basis, x = sol_frac, rows = selected_rows)

            # Assert tableau rows have been computed correctly
            assert_eq(tableau.lhs * sol_frac, tableau.rhs, atol = 1e-3)
            assert_eq(tableau.lhs * sol_opt_s, tableau.rhs, atol = 1e-3)
        end

        @timeit "Compute GMI cuts" begin
            cuts_s = compute_gmi(data_s, tableau)

            # Assert cuts have been generated correctly
            assert_cuts_off(cuts_s, sol_frac)
            assert_does_not_cut_off(cuts_s, sol_opt_s)

            # Abort if no cuts are left
            if length(cuts_s.lb) == 0
                @info "No cuts generated. Aborting."
                break
            else
                @info "Generated $(length(cuts_s.lb)) cuts"
            end
        end

        @timeit "Add GMI cuts to original model" begin
            @timeit "Convert to original form" begin
                cuts = backwards(transforms, cuts_s)
            end

            @timeit "Prepare bv" begin
                bv = repeat([basis], length(selected_rows))
            end

            @timeit "Append matrices" begin
                if round == 1
                    all_cuts = cuts
                    all_cuts_bases = bv
                    all_cuts_rows = selected_rows
                else
                    all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
                    all_cuts.lb = [all_cuts.lb; cuts.lb]
                    all_cuts.ub = [all_cuts.ub; cuts.ub]
                    all_cuts_bases = [all_cuts_bases; bv]
                    all_cuts_rows = [all_cuts_rows; selected_rows]
                end
            end

            @timeit "Add to model" begin
                @info "Adding $(length(all_cuts.lb)) constraints to original model"
                constrs, gmi_exps = add_constraint_set_dual_v2(model, all_cuts)
            end
        end

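        # The pooled cuts are now evaluated on the original model: the LP
        # relaxation is re-solved with all cuts added as explicit constraints,
        # and the shadow price of each cut is recorded. Cuts whose dual value
        # is numerically zero do not affect the current LP bound and are
        # discarded below.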
        @timeit "Optimize original model" begin
            set_objective_function(model, obj_original)
            undo_relax = relax_integrality(model)
            @info "Optimizing original model (constr)..."
            optimize!(model)
            obj = objective_value(model)
            push!(stats_obj, obj)
            push!(stats_gap, gap(obj))
            sp = [shadow_price(c) for c in constrs]
            undo_relax()
            useful = [abs(sp[i]) > 1e-6 for (i, _) in enumerate(constrs)]
            keep = findall(useful .== true)
        end

        @timeit "Filter out useless cuts" begin
            @info "Keeping $(length(keep)) useful cuts"
            all_cuts.lhs = all_cuts.lhs[keep, :]
            all_cuts.lb = all_cuts.lb[keep]
            all_cuts.ub = all_cuts.ub[keep]
            all_cuts_bases = all_cuts_bases[keep, :]
            all_cuts_rows = all_cuts_rows[keep, :]
            push!(stats_ncuts, length(all_cuts_rows))
            if isempty(keep)
                break
            end
        end

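        # Rather than carrying the useful cuts into the next round as explicit
        # constraints, they are deleted and priced into the objective via their
        # shadow prices: the next round optimizes
        # obj_original - Σ sp[i] * gmi_exps[i] over the cuts kept above.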
        @timeit "Update obj function of original model" begin
            delete.(model, constrs)
            set_objective_function(
                model,
                obj_original -
                sum(sp[i] * gmi_exps[i] for (i, c) in enumerate(constrs) if useful[i]),
            )
        end

        elapsed_time = time() - initial_time
        if elapsed_time > time_limit
            @info "Time limit exceeded. Stopping."
            break
        end
    end

    @timeit "Store cuts in H5 file" begin
        if all_cuts !== nothing
            ncuts = length(all_cuts_rows)
            total =
                length(original_basis.var_basic) +
                length(original_basis.var_nonbasic) +
                length(original_basis.constr_basic) +
                length(original_basis.constr_nonbasic)
            all_cuts_basis_sizes = Array{Int64,2}(undef, ncuts, 4)
            all_cuts_basis_vars = Array{Int64,2}(undef, ncuts, total)
            for i = 1:ncuts
                vb = all_cuts_bases[i].var_basic
                vn = all_cuts_bases[i].var_nonbasic
                cb = all_cuts_bases[i].constr_basic
                cn = all_cuts_bases[i].constr_nonbasic
                all_cuts_basis_sizes[i, :] = [length(vb) length(vn) length(cb) length(cn)]
                all_cuts_basis_vars[i, :] = [vb' vn' cb' cn']
            end
            @info "Storing $(length(all_cuts.ub)) GMI cuts..."
            h5 = H5File(h5_filename)
            h5.put_sparse("cuts_lhs", all_cuts.lhs)
            h5.put_array("cuts_lb", all_cuts.lb)
            h5.put_array("cuts_ub", all_cuts.ub)
            h5.put_array("cuts_basis_vars", all_cuts_basis_vars)
            h5.put_array("cuts_basis_sizes", all_cuts_basis_sizes)
            h5.put_array("cuts_rows", all_cuts_rows)
            h5.file.close()
        end
    end

    to = TimerOutputs.get_defaulttimer()
    stats_time = TimerOutputs.tottime(to) / 1e9
    print_timer()

    return OrderedDict(
        "instance" => mps_filename,
        "max_rounds" => max_rounds,
        "rounds" => length(stats_obj) - 1,
        "obj_mip" => obj_mip,
        "stats_obj" => stats_obj,
        "stats_gap" => stats_gap,
        "stats_ncuts" => stats_ncuts,
        "stats_time" => stats_time,
    )
end

# TODO:
# blp-ic98
# neos-3627168-kasai

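# Relax-and-cut collection of GMI cuts, apparently in the style of Fischetti &
# Salvagnin (2011) (reference inferred from the function name): pooled cuts
# are dualized into the objective with Lagrangian multipliers instead of being
# kept as explicit rows. Multipliers are updated either by subgradient steps
# or, periodically, by solving an extended LP and reading its duals, depending
# on `variant`. Active cuts are stored in the instance's H5 file, which must
# already contain an optimal MIP solution.
#
# A minimal usage sketch (hypothetical file name):
#
#     stats = collect_gmi_FisSal2011(
#         "instance.mps.gz";
#         optimizer = HiGHS.Optimizer,
#         variant = :miplearn,
#     )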
function collect_gmi_FisSal2011(
    mps_filename;
    interval_print_sec = 1,
    max_cuts_per_round = 1_000_000,
    max_pool_size_mb = 1024,
    optimizer,
    silent_solver = true,
    time_limit = 300,
    variant = :miplearn,
    verify_cuts = true,
)
    variant in [:subg, :hybr, :fast, :faster, :miplearn] || error("unknown variant: $variant")
    if variant == :subg
        max_rounds = 10_000
        interval_large_lp = 10_000
        interval_read_tableau = 10
    elseif variant == :hybr
        max_rounds = 10_000
        interval_large_lp = 1_000
        interval_read_tableau = 10
    elseif variant == :fast
        max_rounds = 1_000
        interval_large_lp = 100
        interval_read_tableau = 1
    elseif variant == :faster
        max_rounds = 500
        interval_large_lp = 50
        interval_read_tableau = 1
    elseif variant == :miplearn
        max_rounds = 1_000_000
        interval_large_lp = 100
        interval_read_tableau = 1
    end
    gapcl_best_patience = 2 * interval_large_lp + 5

    reset_timer!()
    initial_time = time()

    @timeit "Read H5" begin
        h5_filename = replace(mps_filename, ".mps.gz" => ".h5")
        h5 = H5File(h5_filename, "r")
        sol_opt_dict = Dict(
            zip(
                h5.get_array("static_var_names"),
                convert(Array{Float64}, h5.get_array("mip_var_values")),
            ),
        )
        obj_mip = h5.get_scalar("mip_obj_value")
        h5.file.close()
    end

    @timeit "Initialize" begin
        count_backtrack = 0
        count_deterioration = 0
        gapcl_best = 0
        gapcl_best_history = CircularBuffer{Float64}(gapcl_best_patience)
        gapcl_curr = 0
        last_print_time = 0
        multipliers_best = Float64[]
        multipliers_curr = Float64[]
        obj_best = nothing
        obj_curr = nothing
        obj_hist = CircularBuffer{Float64}(100)
        obj_initial = nothing
        pool = nothing
        pool_cut_age = nothing
        pool_cut_hashes = Set{UInt64}()
        pool_size_mb = 0
        tableau_density::Float32 = 0.05
        basis_cache = nothing
        λ, Δ = 0, 0
        μ = 10

        basis_vars_to_id = Dict()
        basis_id_to_vars = Dict{Int,Vector{Int}}()
        basis_id_to_sizes = Dict{Int,Vector{Int}}()
        next_basis_id = 1
        cut_basis_id = Int[]
        cut_row = Int[]
    end

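    # Percentage of the root gap closed by bound v, measured between the
    # first-round LP bound (obj_initial) and the MIP optimum (obj_mip).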
    gapcl(v) = 100 * (v - obj_initial) / (obj_mip - obj_initial)

    @timeit "Read problem" begin
        model = read_from_file(mps_filename)
        set_optimizer(model, optimizer)
    end

    @timeit "Convert model to standard form" begin
        # Extract problem data
        data = ProblemData(model)

        # Construct optimal solution vector (with correct variable sequence)
        sol_opt = [sol_opt_dict[n] for n in data.var_names]

        # Assert optimal solution is feasible for the original problem
        assert_leq(data.constr_lb, data.constr_lhs * sol_opt)
        assert_leq(data.constr_lhs * sol_opt, data.constr_ub)
        for (var_idx, var_type) in enumerate(data.var_types)
            if var_type in ['B', 'I']
                assert_int(sol_opt[var_idx])
            end
        end

        # Convert to standard form
        data_s, transforms = convert_to_standard_form(data)
        model_s = to_model(data_s)
        if silent_solver
            set_silent(model_s)
        end
        vars_s = all_variables(model_s)
        orig_obj_s = objective_function(model_s)
        set_optimizer(model_s, optimizer)
        relax_integrality(model_s)

        # Convert optimal solution to standard form
        sol_opt_s = forward(transforms, sol_opt)

        # Assert converted solution is feasible for standard form problem
        for (var_idx, var_type) in enumerate(data_s.var_types)
            if var_type in ['B', 'I']
                assert_int(sol_opt_s[var_idx])
            end
        end
        assert_eq(data_s.constr_lhs * sol_opt_s, data_s.constr_lb)
    end

    @info "Standard form model has $(length(data_s.var_lb)) vars, $(length(data_s.constr_lb)) constrs"

    for round = 1:max_rounds
        log_prefix = ' '
        log_should_print = false
        is_last_iteration = false
        if round == max_rounds
            is_last_iteration = true
        end

        elapsed_time = time() - initial_time
        if elapsed_time > time_limit
            @info "Time limit exceeded. Stopping after current iteration."
            is_last_iteration = true
        end

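        # For round > 1, dualize the pooled cuts: subtract
        # Σ u_i * (a_iᵀ x - b_i) from the original objective, where u_i are
        # the current (nonnegative) multipliers. Only cuts with multipliers
        # above tolerance contribute to the matrix-vector product below.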
        if round > 1
            @timeit "Build Lagrangian term" begin
                @timeit "mul" begin
                    active_idx = findall(multipliers_curr .> 1e-6)
                    v = sparse(pool.lhs[:, active_idx] * multipliers_curr[active_idx])
                end
                @timeit "dot" begin
                    lagr_term = AffExpr(dot(multipliers_curr, pool.lb))
                end
                @timeit "add_to_expression!" begin
                    for offset in 1:nnz(v)
                        var_idx = v.nzind[offset]
                        add_to_expression!(
                            lagr_term,
                            vars_s[var_idx],
                            -v.nzval[offset],
                        )
                    end
                end
            end
            @timeit "Update objective" begin
                set_objective_function(
                    model_s,
                    orig_obj_s + lagr_term,
                )
            end
        end

        @timeit "Optimize LP (lagrangian)" begin
            basis_cache === nothing || set_basis(model_s, basis_cache)
            set_silent(model_s)
            optimize!(model_s)
            basis_cache = get_basis(model_s)
            status = termination_status(model_s)
            if status == MOI.DUAL_INFEASIBLE
                @warn "LP is unbounded (dual infeasible). Resetting to best known multipliers."
                copy!(multipliers_curr, multipliers_best)
                obj_curr = obj_best
                continue
            elseif status != MOI.OPTIMAL
                error("Non-optimal termination status: $status")
            end
            sol_frac = get_x(model_s)
            obj_curr = objective_value(model_s)
        end

        @timeit "Update history and μ" begin
            push!(obj_hist, obj_curr)
            if obj_best === nothing || obj_curr > obj_best
                log_prefix = '*'
                obj_best = obj_curr
                copy!(multipliers_best, multipliers_curr)
            end
            if round == 1
                obj_initial = obj_curr
            end
            gapcl_curr = gapcl(obj_curr)
            gapcl_best = gapcl(obj_best)
            push!(gapcl_best_history, gapcl_best)
            if variant in [:subg, :hybr]
                Δ = obj_mip - obj_best
                if obj_curr < obj_best - Δ
                    count_deterioration += 1
                else
                    count_deterioration = 0
                end
                if count_deterioration >= 10
                    μ *= 0.5
                    copy!(multipliers_curr, multipliers_best)
                    count_deterioration = 0
                    count_backtrack += 1
                elseif length(obj_hist) >= 100
                    obj_hist_avg = mean(obj_hist)
                    improv = obj_best - obj_hist[1]
                    if improv < 0.01 * Δ
                        if obj_best - obj_hist_avg < 0.001 * Δ
                            μ = 10 * μ
                        elseif obj_best - obj_hist_avg < 0.01 * Δ
                            μ = 2 * μ
                        else
                            μ = 0.5 * μ
                        end
                    end
                end
            elseif variant in [:fast, :faster, :miplearn]
                μ = 0.01
            else
                error("not implemented")
            end
        end

        if mod(round - 1, interval_read_tableau) == 0
            @timeit "Get basis" begin
                basis = get_basis(model_s)
            end
            @timeit "Select tableau rows" begin
                selected_rows =
                    select_gmi_rows(data_s, basis, sol_frac, max_rows = max_cuts_per_round)
            end

            @timeit "Compute tableau rows" begin
                tableau = compute_tableau(
                    data_s,
                    basis,
                    x = sol_frac,
                    rows = selected_rows,
                    estimated_density = tableau_density * 1.05,
                )
                tableau_density = nnz(tableau.lhs) / length(tableau.lhs)
                assert_eq(tableau.lhs * sol_frac, tableau.rhs, atol = 1e-3)
                assert_eq(tableau.lhs * sol_opt_s, tableau.rhs, atol = 1e-3)
            end

            @timeit "Compute GMI cuts" begin
                cuts_s = compute_gmi(data_s, tableau)
            end

            @timeit "Check cut validity" begin
                assert_cuts_off(cuts_s, sol_frac)
                assert_does_not_cut_off(cuts_s, sol_opt_s)
                ncuts = length(cuts_s.lb)
            end

            @timeit "Add new cuts to the pool" begin
                @timeit "Compute cut hashes" begin
                    unique_indices = Int[]
                    for i in 1:ncuts
                        cut_hash = cuts_s.hash[i]
                        if !(cut_hash in pool_cut_hashes)
                            push!(pool_cut_hashes, cut_hash)
                            push!(unique_indices, i)
                        end
                    end
                end
                @timeit "Append unique cuts" begin
                    @timeit "Track basis" begin
                        vb = basis.var_basic
                        vn = basis.var_nonbasic
                        cb = basis.constr_basic
                        cn = basis.constr_nonbasic
                        basis_vars = [vb; vn; cb; cn]
                        basis_sizes = [length(vb), length(vn), length(cb), length(cn)]

                        if basis_vars ∉ keys(basis_vars_to_id)
                            basis_id = next_basis_id
                            basis_vars_to_id[basis_vars] = basis_id
                            basis_id_to_vars[basis_id] = basis_vars
                            basis_id_to_sizes[basis_id] = basis_sizes
                            next_basis_id += 1
                        else
                            basis_id = basis_vars_to_id[basis_vars]
                        end
                    end

                    if round == 1
                        pool = ConstraintSet(
                            lhs = sparse(cuts_s.lhs[unique_indices, :]'),
                            lb = cuts_s.lb[unique_indices],
                            ub = cuts_s.ub[unique_indices],
                            hash = cuts_s.hash[unique_indices],
                        )
                        ncuts_unique = length(unique_indices)
                        multipliers_curr = zeros(ncuts_unique)
                        multipliers_best = zeros(ncuts_unique)
                        pool_cut_age = zeros(ncuts_unique)
                        for i in unique_indices
                            push!(cut_basis_id, basis_id)
                            push!(cut_row, selected_rows[i])
                        end
                    else
                        if !isempty(unique_indices)
                            @timeit "Append LHS" begin
                                # Transpose cuts matrix for better performance
                                new_cuts_lhs = sparse(cuts_s.lhs[unique_indices, :]')

                                # Resize existing matrix in-place to accommodate new columns
                                old_cols = pool.lhs.n
                                new_cols = new_cuts_lhs.n
                                total_cols = old_cols + new_cols
                                resize!(pool.lhs.colptr, total_cols + 1)

                                # Append new column pointers with offset
                                old_nnz = nnz(pool.lhs)
                                for i in 1:new_cols
                                    pool.lhs.colptr[old_cols+i+1] = old_nnz + new_cuts_lhs.colptr[i+1]
                                end

                                # Expand rowval and nzval arrays
                                append!(pool.lhs.rowval, new_cuts_lhs.rowval)
                                append!(pool.lhs.nzval, new_cuts_lhs.nzval)

                                # Update matrix dimensions
                                pool.lhs = SparseMatrixCSC(
                                    pool.lhs.m,
                                    total_cols,
                                    pool.lhs.colptr,
                                    pool.lhs.rowval,
                                    pool.lhs.nzval,
                                )
                            end
                            @timeit "Append others" begin
                                ncuts_unique = length(unique_indices)
                                append!(pool.lb, cuts_s.lb[unique_indices])
                                append!(pool.ub, cuts_s.ub[unique_indices])
                                append!(pool.hash, cuts_s.hash[unique_indices])
                                append!(multipliers_curr, zeros(ncuts_unique))
                                append!(multipliers_best, zeros(ncuts_unique))
                                append!(pool_cut_age, zeros(ncuts_unique))
                                for i in unique_indices
                                    push!(cut_basis_id, basis_id)
                                    push!(cut_row, selected_rows[i])
                                end
                            end
                        end
                    end
                end
            end

            @timeit "Prune the pool" begin
                pool_size_mb = Base.summarysize(pool) / 1024^2
                while pool_size_mb >= max_pool_size_mb
                    @timeit "Identify cuts to remove" begin
                        scores = collect(zip(multipliers_best .> 1e-6, -pool_cut_age))
                        σ = sortperm(scores, rev = true)
                        pool_size = length(pool.ub)
                        n_keep = Int(floor(pool_size * 0.8))
                        idx_keep = σ[1:n_keep]
                        idx_remove = σ[(n_keep+1):end]

                        positive_multipliers_dropped = sum(multipliers_best[idx_remove] .> 1e-6)
                        @info "Dropping $(length(idx_remove)) cuts ($(positive_multipliers_dropped) with positive best multipliers)"
                    end
                    @timeit "Update cut hashes" begin
                        for idx in idx_remove
                            cut_hash = pool.hash[idx]
                            delete!(pool_cut_hashes, cut_hash)
                        end
                    end
                    @timeit "Update cut pool" begin
                        pool.lhs = pool.lhs[:, idx_keep]
                        pool.ub = pool.ub[idx_keep]
                        pool.lb = pool.lb[idx_keep]
                        pool.hash = pool.hash[idx_keep]
                        multipliers_curr = multipliers_curr[idx_keep]
                        multipliers_best = multipliers_best[idx_keep]
                        pool_cut_age = pool_cut_age[idx_keep]
                        cut_basis_id = cut_basis_id[idx_keep]
                        cut_row = cut_row[idx_keep]
                    end
                    @timeit "Update known bases" begin
                        used_basis_ids = Set(cut_basis_id)
                        for basis_id in collect(keys(basis_id_to_vars))
                            if basis_id ∉ used_basis_ids
                                basis_vars = basis_id_to_vars[basis_id]
                                delete!(basis_vars_to_id, basis_vars)
                                delete!(basis_id_to_vars, basis_id)
                                delete!(basis_id_to_sizes, basis_id)
                            end
                        end
                    end
                    pool_size_mb = Base.summarysize(pool) / 1024^2
                end
            end
        end

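        # Two multiplier-update modes: every `interval_large_lp` rounds (and on
        # the last iteration), an extended LP with the most violated pooled
        # cuts added explicitly is solved, and the multipliers are reset from
        # its duals; on all other rounds, a projected subgradient step with a
        # Polyak-style step length λ = μ (obj_mip - obj_curr) / ‖g‖² is taken.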
        if mod(round - 1, interval_large_lp) == 0 || is_last_iteration
            log_should_print = true
            @timeit "Update multipliers (large LP)" begin
                selected_idx = []
                selected_constrs = []
                while true
                    @timeit "Optimize LP (extended)" begin
                        set_silent(model_s)
                        set_objective_function(model_s, orig_obj_s)
                        optimize!(model_s)
                        status = termination_status(model_s)
                        if status != MOI.OPTIMAL
                            error("Non-optimal termination status: $status")
                        end
                        obj_curr = objective_value(model_s)
                        sol_frac = get_x(model_s)
                    end

                    @timeit "Computing cut violations" begin
                        violations = pool.lb - pool.lhs' * sol_frac
                    end

                    @timeit "Sorting cut violations" begin
                        σ = sortperm(violations, rev = true)
                    end

                    if violations[σ[1]] <= 1e-6
                        break
                    end

                    @timeit "Add constraints to the model" begin
                        ncuts = min(max(1, sum(violations .> 1e-6) ÷ 10), length(σ))
                        for i in 1:ncuts
                            if violations[σ[i]] <= 1e-6
                                break
                            end
                            cut_lhs = pool.lhs[:, σ[i]]
                            cut_lhs_value = 0.0
                            cut_lb = pool.lb[σ[i]]
                            cut_expr = AffExpr()
                            for offset in 1:nnz(cut_lhs)
                                var_idx = cut_lhs.nzind[offset]
                                add_to_expression!(
                                    cut_expr,
                                    vars_s[var_idx],
                                    cut_lhs.nzval[offset],
                                )
                                cut_lhs_value += sol_frac[var_idx] * cut_lhs.nzval[offset]
                            end
                            cut_constr = @constraint(model_s, cut_expr >= cut_lb)
                            push!(selected_idx, σ[i])
                            push!(selected_constrs, cut_constr)
                        end
                    end
                end

                @timeit "Find dual values for all selected cuts" begin
                    multipliers_curr .= 0
                    pool_cut_age .+= 1
                    for (offset, idx) in enumerate(selected_idx)
                        multipliers_curr[idx] = -shadow_price(selected_constrs[offset])
                        if multipliers_curr[idx] > 1e-5
                            pool_cut_age[idx] = 0
                        end
                    end
                end

                @timeit "Update best" begin
                    if obj_curr > obj_best
                        log_prefix = '*'
                        obj_best = obj_curr
                        copy!(multipliers_best, multipliers_curr)
                    end
                    gapcl_curr = gapcl(obj_curr)
                    gapcl_best = gapcl(obj_best)
                end

                @timeit "Delete all cut constraints" begin
                    delete.(model_s, selected_constrs)
                end
            end
        else
            @timeit "Update multipliers (subgradient)" begin
                subgrad = (pool.lb' - (sol_frac' * pool.lhs))'
                subgrad_norm_sq = norm(subgrad)^2
                if subgrad_norm_sq < 1e-10
                    λ = 0
                else
                    λ = μ * (obj_mip - obj_curr) / subgrad_norm_sq
                end
                multipliers_curr = max.(0, multipliers_curr .+ λ * subgrad)
            end
        end

        if round == 1
            @printf(
                " %8s %8s %10s %9s %9s %9s %9s %4s %8s %8s %8s\n",
                "time",
                "round",
                "obj",
                "cl_curr",
                "cl_best",
                "pool_cuts",
                "pool_mb",
                "bktk",
                "Δ",
                "μ",
                "λ",
            )
        end

        if time() - last_print_time > interval_print_sec
            log_should_print = true
        end

        if is_last_iteration
            log_should_print = true
        end

        if log_should_print
            last_print_time = time()
            @printf(
                "%c %8.2f %8d %10.3e %9.2e %9.2e %9d %9.2f %4d %8.2e %8.2e %8.2e\n",
                log_prefix,
                time() - initial_time,
                round,
                obj_curr,
                gapcl_curr,
                gapcl_best,
                length(pool.ub),
                pool_size_mb,
                count_backtrack,
                Δ,
                μ,
                λ,
            )
        end

        push!(gapcl_best_history, gapcl_best)
        if length(gapcl_best_history) >= gapcl_best_patience
            if gapcl_best <= gapcl_best_history[1]
                @info "No gap closure improvement. Stopping."
                break
            end
        end

        if is_last_iteration
            break
        end
    end

    @info "Best gap closure: $(gapcl_best)"

    @timeit "Keep only active cuts" begin
        positive_idx = findall(multipliers_best .> 1e-6)
        if length(positive_idx) == 0 && gapcl_best > 0
            error("gap closure with zero cuts")
        end

        @timeit "Clean up cut pool" begin
            pool.lhs = pool.lhs[:, positive_idx]
            pool.lb = pool.lb[positive_idx]
            pool.ub = pool.ub[positive_idx]
            pool.hash = pool.hash[positive_idx]
            multipliers_best = multipliers_best[positive_idx]
            multipliers_curr = multipliers_curr[positive_idx]
            cut_basis_id = cut_basis_id[positive_idx]
            cut_row = cut_row[positive_idx]
        end

        @timeit "Clean up known bases" begin
            used_basis_ids = Set(cut_basis_id)
            for basis_id in collect(keys(basis_id_to_vars))
                if basis_id ∉ used_basis_ids
                    basis_vars = basis_id_to_vars[basis_id]
                    delete!(basis_vars_to_id, basis_vars)
                    delete!(basis_id_to_vars, basis_id)
                    delete!(basis_id_to_sizes, basis_id)
                end
            end
        end

        @info "Keeping $(length(positive_idx)) cuts from $(length(used_basis_ids)) unique bases"
    end

    to = TimerOutputs.get_defaulttimer()
    stats_time = TimerOutputs.tottime(to) / 1e9
    print_timer()

    if length(positive_idx) > 0
        @timeit "Write cuts to H5" begin
            if !isempty(cut_basis_id)
                @timeit "Convert IDs to offsets" begin
                    id_to_offset = Dict{Int,Int}()
                    gmi_basis_vars = []
                    gmi_basis_sizes = []
                    for (offset, basis_id) in enumerate(sort(collect(keys(basis_id_to_vars))))
                        id_to_offset[basis_id] = offset
                        push!(gmi_basis_vars, basis_id_to_vars[basis_id])
                        push!(gmi_basis_sizes, basis_id_to_sizes[basis_id])
                    end
                    gmi_cut_basis = [id_to_offset[basis_id] for basis_id in cut_basis_id]
                    gmi_cut_row = cut_row
                end

                @timeit "Convert to matrices" begin
                    gmi_basis_vars_matrix = hcat(gmi_basis_vars...)'
                    gmi_basis_sizes_matrix = hcat(gmi_basis_sizes...)'
                end

                @timeit "Write H5" begin
                    h5 = H5File(h5_filename, "r+")
                    h5.put_array("gmi_basis_vars", gmi_basis_vars_matrix)
                    h5.put_array("gmi_basis_sizes", gmi_basis_sizes_matrix)
                    h5.put_array("gmi_cut_basis", gmi_cut_basis)
                    h5.put_array("gmi_cut_row", gmi_cut_row)
                    h5.file.close()
                end
            end
        end

        if verify_cuts
            @timeit "Verify cuts in current model" begin
                @info "Verifying cuts in current standard form model using pool..."
                if !isempty(cut_basis_id)
                    @info "Adding $(length(pool.lb)) cuts from pool to current model"
                    pool.lhs = sparse(pool.lhs')
                    constrs = build_constraints(model_s, pool)
                    add_constraint.(model_s, constrs)
                    set_objective_function(model_s, orig_obj_s)
                    optimize!(model_s)
                    status = termination_status(model_s)
                    if status != MOI.OPTIMAL
                        error("Non-optimal termination status: $status")
                    end
                    obj_verify_s = objective_value(model_s)
                    gapcl_verify_s = gapcl(obj_verify_s)
                    @show gapcl_verify_s
                    @show gapcl_best
                    if abs(gapcl_best - gapcl_verify_s) > 0.01
                        error("Gap closures differ: $(gapcl_best) ≠ $(gapcl_verify_s)")
                    end
                    @info "Current model gap closure matches: $(gapcl_best) ≈ $(gapcl_verify_s)"
                else
                    @warn "No cuts in pool to verify"
                end
            end

            @timeit "Verify stored cuts" begin
                @info "Verifying stored cuts..."
                model_verify = read_from_file(mps_filename)
                set_optimizer(model_verify, optimizer)
                verification_cuts = _dualgmi_generate([h5_filename], model_verify; test_h5 = h5_filename)
                constrs = build_constraints(model_verify, verification_cuts)
                add_constraint.(model_verify, constrs)
                relax_integrality(model_verify)
                optimize!(model_verify)
                status = termination_status(model_verify)
                if status != MOI.OPTIMAL
                    error("Non-optimal termination status: $status")
                end
                obj_verify = objective_value(model_verify)
                gapcl_verify = gapcl(obj_verify)
                @show gapcl_verify
                @show gapcl_best
                if abs(gapcl_best - gapcl_verify) > 0.01
                    error("Gap closures differ: $(gapcl_best) ≠ $(gapcl_verify)")
                end
                @info "Gap closure matches gapcl_best: $(gapcl_best) ≈ $(gapcl_verify)"
            end
        end
    end

    return OrderedDict(
        "gapcl_best" => gapcl_best,
        "gapcl_curr" => gapcl_curr,
        "instance" => mps_filename,
        "obj_final" => obj_curr,
        "obj_initial" => obj_initial,
        "obj_mip" => obj_mip,
        "pool_size_mb" => pool_size_mb,
        "pool_total" => length(pool.lb),
        "time" => stats_time,
    )
end

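# Add every row of a ConstraintSet to `model` and return both the JuMP
# constraint references and, for each row, an affine "GMI expression"
# (lb - expr, or expr - ub) whose dual-weighted sum is later subtracted from
# the objective by collect_gmi_dual. Range rows contribute two expressions.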
function add_constraint_set_dual_v2(model::JuMP.Model, cs::ConstraintSet)
    vars = all_variables(model)
    nrows, ncols = size(cs.lhs)

    @timeit "Transpose LHS" begin
        lhs_t = spzeros(ncols, nrows)
        ftranspose!(lhs_t, cs.lhs, x -> x)
        lhs_t_rows = rowvals(lhs_t)
        lhs_t_vals = nonzeros(lhs_t)
    end

    constrs = []
    gmi_exps = []
    for i = 1:nrows
        c = nothing
        gmi_exp = nothing
        gmi_exp2 = nothing
        @timeit "Build expr" begin
            expr = AffExpr()
            for k in nzrange(lhs_t, i)
                add_to_expression!(expr, lhs_t_vals[k], vars[lhs_t_rows[k]])
            end
        end
        @timeit "Add constraints" begin
            if isinf(cs.ub[i])
                c = @constraint(model, cs.lb[i] <= expr)
                gmi_exp = cs.lb[i] - expr
            elseif isinf(cs.lb[i])
                c = @constraint(model, expr <= cs.ub[i])
                gmi_exp = expr - cs.ub[i]
            else
                c = @constraint(model, cs.lb[i] <= expr <= cs.ub[i])
                gmi_exp = cs.lb[i] - expr
                gmi_exp2 = expr - cs.ub[i]
            end
        end
        @timeit "Update structs" begin
            push!(constrs, c)
            push!(gmi_exps, gmi_exp)
            if !isnothing(gmi_exp2)
                push!(gmi_exps, gmi_exp2)
            end
        end
    end
    return constrs, gmi_exps
end

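# Extract the feature vector of one instance from its H5 file, using the given
# MIPLearn feature extractor.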
function _dualgmi_features(h5_filename, extractor)
    h5 = H5File(h5_filename, "r")
    try
        return extractor.get_instance_features(h5)
    finally
        h5.close()
    end
end

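# Rewrite the per-cut basis data stored by collect_gmi_dual in a compressed
# form: duplicate bases are stored only once (gmi_basis_vars/gmi_basis_sizes),
# and each cut keeps just its basis offset and tableau row
# (gmi_cut_basis/gmi_cut_row).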
function _dualgmi_compress_h5(h5_filename)
    vars_to_basis_offset = Dict()
    basis_vars = []
    basis_sizes = []
    cut_basis::Array{Int} = []
    cut_row::Array{Int} = []

    h5 = H5File(h5_filename, "r")
    orig_cut_basis_vars = h5.get_array("cuts_basis_vars")
    orig_cut_basis_sizes = h5.get_array("cuts_basis_sizes")
    orig_cut_rows = h5.get_array("cuts_rows")
    h5.close()
    if orig_cut_basis_vars === nothing
        @warn "orig_cut_basis_vars is null; skipping file"
        return
    end
    ncuts, _ = size(orig_cut_basis_vars)
    if ncuts == 0
        return
    end

    for i in 1:ncuts
        vars = orig_cut_basis_vars[i, :]
        sizes = orig_cut_basis_sizes[i, :]
        row = orig_cut_rows[i]
        if vars ∉ keys(vars_to_basis_offset)
            offset = size(basis_vars)[1] + 1
            vars_to_basis_offset[vars] = offset
            push!(basis_vars, vars)
            push!(basis_sizes, sizes)
        end
        offset = vars_to_basis_offset[vars]
        push!(cut_basis, offset)
        push!(cut_row, row)
    end

    basis_vars = hcat(basis_vars...)'
    basis_sizes = hcat(basis_sizes...)'
    _, n_vars = size(basis_vars)
    if n_vars == 0
        @warn "n_vars is zero; skipping file"
        return
    end

    h5 = H5File(h5_filename, "r+")
    h5.put_array("gmi_basis_vars", basis_vars)
    h5.put_array("gmi_basis_sizes", basis_sizes)
    h5.put_array("gmi_cut_basis", cut_basis)
    h5.put_array("gmi_cut_row", cut_row)
    h5.file.close()
end

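# Regenerate GMI cuts for `model` from the bases and tableau rows stored in
# the given training H5 files: duplicate bases across files are merged, the
# corresponding tableaus are recomputed on the new instance, and all resulting
# cuts are combined. If `test_h5` provides an optimal solution, it is used for
# validation, and tableaus whose cuts fail the validity checks are skipped.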
function _dualgmi_generate(train_h5, model; test_h5 = nothing)
    @timeit "Read problem data" begin
        data = ProblemData(model)
    end
    @timeit "Convert to standard form" begin
        data_s, transforms = convert_to_standard_form(data)
    end
    @timeit "Read optimal solution from test H5" begin
        sol_opt_dict = nothing
        sol_opt = nothing
        sol_opt_s = nothing
        if test_h5 !== nothing
            try
                h5 = H5File(test_h5, "r")
                var_names = h5.get_array("static_var_names")
                var_values = h5.get_array("mip_var_values")
                h5.close()
                if var_names !== nothing && var_values !== nothing
                    sol_opt_dict = Dict(zip(var_names, convert(Array{Float64}, var_values)))
                    sol_opt = [sol_opt_dict[n] for n in data.var_names]
                    sol_opt_s = forward(transforms, sol_opt)
                    @info "Loaded optimal solution for cut validation"
                end
            catch e
                @warn "Could not read optimal solution from test H5 file: $e"
            end
        end
    end
    @timeit "Collect cuts from H5 files" begin
        basis_vars_to_basis_offset = Dict()
        combined_basis_sizes = nothing
        combined_basis_sizes_list = Any[]
        combined_basis_vars = nothing
        combined_basis_vars_list = Any[]
        combined_cut_rows = Any[]
        for h5_filename in train_h5
            @timeit "get_array (new)" begin
                h5 = H5File(h5_filename, "r")
                gmi_basis_vars = h5.get_array("gmi_basis_vars")
                if gmi_basis_vars === nothing
                    @warn "$(h5_filename) does not contain gmi_basis_vars; skipping"
                    continue
                end
                gmi_basis_sizes = h5.get_array("gmi_basis_sizes")
                gmi_cut_basis = h5.get_array("gmi_cut_basis")
                gmi_cut_row = h5.get_array("gmi_cut_row")
                h5.close()
            end
            @timeit "combine basis" begin
                nbasis, _ = size(gmi_basis_vars)
                local_to_combined_offset = Dict()
                for local_offset in 1:nbasis
                    vars = gmi_basis_vars[local_offset, :]
                    sizes = gmi_basis_sizes[local_offset, :]
                    if vars ∉ keys(basis_vars_to_basis_offset)
                        combined_offset = length(combined_basis_vars_list) + 1
                        basis_vars_to_basis_offset[vars] = combined_offset
                        push!(combined_basis_vars_list, vars)
                        push!(combined_basis_sizes_list, sizes)
                        push!(combined_cut_rows, Set{Int}())
                    end
                    combined_offset = basis_vars_to_basis_offset[vars]
                    local_to_combined_offset[local_offset] = combined_offset
                end
            end
            @timeit "combine rows" begin
                ncuts = length(gmi_cut_row)
                for i in 1:ncuts
                    local_offset = gmi_cut_basis[i]
                    combined_offset = local_to_combined_offset[local_offset]
                    row = gmi_cut_row[i]
                    push!(combined_cut_rows[combined_offset], row)
                end
            end
            @timeit "convert lists to matrices" begin
                combined_basis_vars = hcat(combined_basis_vars_list...)'
                combined_basis_sizes = hcat(combined_basis_sizes_list...)'
            end
        end
    end
    @timeit "Compute tableaus and cuts" begin
        all_cuts = nothing
        nbasis = length(combined_cut_rows)
        for offset in 1:nbasis
            rows = combined_cut_rows[offset]
            try
                vbb, vnn, cbb, cnn = combined_basis_sizes[offset, :]
                current_basis = Basis(;
                    var_basic = combined_basis_vars[offset, 1:vbb],
                    var_nonbasic = combined_basis_vars[offset, vbb+1:vbb+vnn],
                    constr_basic = combined_basis_vars[offset, vbb+vnn+1:vbb+vnn+cbb],
                    constr_nonbasic = combined_basis_vars[offset, vbb+vnn+cbb+1:vbb+vnn+cbb+cnn],
                )
                tableau = compute_tableau(data_s, current_basis; rows = collect(rows))
                cuts_s = compute_gmi(data_s, tableau)
                cuts = backwards(transforms, cuts_s)
                if sol_opt_s !== nothing && sol_opt !== nothing
                    assert_does_not_cut_off(cuts_s, sol_opt_s)
                    assert_does_not_cut_off(cuts, sol_opt)
                end
                if all_cuts === nothing
                    all_cuts = cuts
                else
                    all_cuts.lhs = [all_cuts.lhs; cuts.lhs]
                    all_cuts.lb = [all_cuts.lb; cuts.lb]
                    all_cuts.ub = [all_cuts.ub; cuts.ub]
                end
            catch e
                if isa(e, AssertionError)
                    @warn "Numerical error detected. Skipping cuts from current tableau."
                    continue
                else
                    rethrow(e)
                end
            end
        end
    end
    return all_cuts
end

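# Install a one-shot JuMP user-cut callback: on its first invocation it
# submits all generated cuts via MOI.UserCut, then clears `all_cuts` so the
# cuts are not submitted again.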
function _dualgmi_set_callback(model, all_cuts)
    function cut_callback(cb_data)
        if all_cuts !== nothing
            constrs = build_constraints(model, all_cuts)
            @info "Enforcing $(length(constrs)) cuts..."
            for c in constrs
                MOI.submit(model, MOI.UserCut(cb_data), c)
            end
            all_cuts = nothing
        end
    end
    set_attribute(model, MOI.UserCutCallback(), cut_callback)
end

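# Fit the k-NN component: stack the feature vectors of all training instances
# and index them with scikit-learn's NearestNeighbors (via PyCall).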
function KnnDualGmiComponent_fit(data::_KnnDualGmiData, train_h5)
    x = hcat([_dualgmi_features(filename, data.extractor) for filename in train_h5]...)'
    model = pyimport("sklearn.neighbors").NearestNeighbors(n_neighbors = length(train_h5))
    model.fit(x)
    data.model = model
    data.train_h5 = train_h5
end

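# Before solving a test instance: extract its features, pick k training
# instances according to `data.strategy` ("near" = closest neighbors,
# "far" = most distant, "rand" = random), regenerate GMI cuts from their
# stored bases, and attach the cuts to the model through a user-cut callback.
# Note that neighbor indices returned by scikit-learn are 0-based and are
# shifted by one below.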
function KnnDualGmiComponent_before_mip(data::_KnnDualGmiData, test_h5, model, _)
    reset_timer!()

    @timeit "Extract features" begin
        x = _dualgmi_features(test_h5, data.extractor)
        x = reshape(x, 1, length(x))
    end

    @timeit "Find neighbors" begin
        neigh_dist, neigh_ind = data.model.kneighbors(x, return_distance = true)
        neigh_ind = neigh_ind .+ 1
        N = length(neigh_dist)
        k = min(N, data.k)

        if data.strategy == "near"
            selected = collect(1:k)
        elseif data.strategy == "far"
            selected = collect((N-k+1):N)
        elseif data.strategy == "rand"
            selected = shuffle(collect(1:N))[1:k]
        else
            error("unknown strategy: $(data.strategy)")
        end

        @info "Dual GMI: Selected neighbors ($(data.strategy)):"
        neigh_dist = neigh_dist[selected]
        neigh_ind = neigh_ind[selected]
        for i in 1:k
            h5_filename = data.train_h5[neigh_ind[i]]
            dist = neigh_dist[i]
            @info " $(h5_filename) dist=$(dist)"
        end
    end

    @info "Dual GMI: Generating cuts..."
    @timeit "Generate cuts" begin
        time_generate = @elapsed begin
            cuts = _dualgmi_generate(data.train_h5[neigh_ind], model)
        end
        @info "Dual GMI: Generated $(length(cuts.lb)) unique cuts in $(time_generate) seconds"
    end

    @timeit "Set callback" begin
        _dualgmi_set_callback(model, cuts)
    end

    print_timer()

    stats = Dict()
    stats["KnnDualGmi: k"] = k
    stats["KnnDualGmi: strategy"] = data.strategy
    stats["KnnDualGmi: cuts"] = length(cuts.lb)
    stats["KnnDualGmi: time generate"] = time_generate
    return stats
end

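# Module initializer: build the PyCall wrapper classes. KnnDualGmiComponentPy
# exposes the k-NN component to MIPLearn's Python API; ExpertDualGmiComponentPy
# is the "expert" variant, which fits a 1-NN model on the test instance itself
# and therefore replays that instance's own stored cuts.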
function __init_gmi_dual__()
    @pydef mutable struct KnnDualGmiComponentPy
        function __init__(self; extractor, k = 3, strategy = "near")
            self.data = _KnnDualGmiData(; extractor, k, strategy)
        end
        function fit(self, train_h5)
            KnnDualGmiComponent_fit(self.data, train_h5)
        end
        function before_mip(self, test_h5, model, stats)
            return @time KnnDualGmiComponent_before_mip(self.data, test_h5, model.inner, stats)
        end
    end
    copy!(KnnDualGmiComponent, KnnDualGmiComponentPy)

    @pydef mutable struct ExpertDualGmiComponentPy
        function __init__(self)
            self.inner = KnnDualGmiComponentPy(
                extractor = H5FieldsExtractor(instance_fields = ["static_var_obj_coeffs"]),
                k = 1,
            )
        end
        function fit(self, train_h5)
        end
        function before_mip(self, test_h5, model, stats)
            self.inner.fit([test_h5])
            return self.inner.before_mip(test_h5, model, stats)
        end
    end
    copy!(ExpertDualGmiComponent, ExpertDualGmiComponentPy)
end

export collect_gmi_dual, expert_gmi_dual, ExpertDualGmiComponent, KnnDualGmiComponent, collect_gmi_FisSal2011