You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
4.9 KiB
154 lines
4.9 KiB
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
|
|
# Copyright (C) 2020-2023, UChicago Argonne, LLC. All rights reserved.
|
|
# Released under the modified BSD license. See COPYING.md for more details.
|
|
|
|
using KLU
|
|
using TimerOutputs
|
|
|
|
"""
    get_basis(model::JuMP.Model)::Basis

Read the simplex basis reported for `model` and return it as a `Basis`.

Variables are numbered by their position in `all_variables(model)`. A variable
whose `MOI.VariableBasisStatus` is `MOI.BASIC` is recorded in `var_basic`; one
that is `MOI.NONBASIC_AT_LOWER` is recorded in `var_nonbasic`; any other status
(for example nonbasic at the upper bound) raises an error.

Constraints are numbered consecutively over the affine (`AffExpr`) constraints
only, visiting constraint types in the order given by
`list_of_constraint_types(model)`. Their `MOI.ConstraintBasisStatus` must be
`MOI.BASIC` or `MOI.NONBASIC`; any other status raises an error. Constraints
whose function is a single `VariableRef` are skipped, and every other function
type raises an error.

NOTE(review): presumably requires the model to have been solved as an LP with
basis information available from the solver — confirm at call sites.
"""
function get_basis(model::JuMP.Model)::Basis
    var_basic, var_nonbasic = Int[], Int[]
    constr_basic, constr_nonbasic = Int[], Int[]

    # Variables: classify each one by its position in all_variables(model).
    for (idx, v) in enumerate(all_variables(model))
        status = MOI.get(model, MOI.VariableBasisStatus(), v)
        if status == MOI.BASIC
            push!(var_basic, idx)
        elseif status == MOI.NONBASIC_AT_LOWER
            push!(var_nonbasic, idx)
        else
            error("Unknown basis status: $status")
        end
    end

    # Constraints: only affine rows receive an index; single-variable
    # constraints are ignored and any other function type is rejected.
    row = 0
    for (F, S) in list_of_constraint_types(model)
        F == VariableRef && continue
        for c in all_constraints(model, F, S)
            F == AffExpr || error("Unsupported constraint type: ($F, $S)")
            row += 1
            status = MOI.get(model, MOI.ConstraintBasisStatus(), c)
            if status == MOI.BASIC
                push!(constr_basic, row)
            elseif status == MOI.NONBASIC
                push!(constr_nonbasic, row)
            else
                error("Unknown basis status: $status")
            end
        end
    end

    return Basis(; var_basic, var_nonbasic, constr_basic, constr_nonbasic)
end
|
|
|
|
"""
    get_x(model::JuMP.Model)

Return the current primal value of every variable in `model`, as a vector
ordered like `all_variables(model)`.
"""
function get_x(model::JuMP.Model)
    vars = all_variables(model)
    return [JuMP.value(v) for v in vars]
end
|
|
|
|
"""
    compute_tableau(data::ProblemData, basis::Basis; x=nothing, rows=nothing,
                    tol=1e-8, estimated_density=0.10)::Tableau

Compute (selected rows of) the simplex tableau of `data` for the given `basis`.

With `A = data.constr_lhs`, the basis matrix is
`B = [A[:, basis.var_basic]  Id[:, basis.constr_basic]]`, where `Id` is the
identity on the constraint rows (one slack column per row). `B` must be square.
Its transpose is factorized once with KLU and reused for every requested row.

For each requested row index `r`, the function solves `B' y = e_r` (so `y'` is
row `r` of `inv(B)`) and stores:
- the `lhs` row `y' A`, restricted to the structural columns of `A` (slack
  columns are not included), dropping entries with magnitude `<= tol`;
- the `rhs` entry `y' data.constr_ub`.

The objective row is `-data.obj' + yc' A`, where `B' yc` equals the objective
coefficients of the basic columns (zero for basic slacks). Entries with
magnitude `< tol` are set to zero.

# Keywords
- `x`: optional solution vector; when given, `z = dot(data.obj, x)`,
  otherwise `z = 0.0`.
- `rows`: constraint-row indices to compute, in output order (defaults to all
  rows). Any integer vector, including a range, is accepted.
- `tol`: threshold below which entries are treated as zero.
- `estimated_density`: initial guess of the fraction of nonzeros in the
  computed rows; used only to pre-size internal buffers.

# Throws
- `DimensionMismatch` if the basis does not have exactly one column per
  constraint row.
"""
function compute_tableau(
    data::ProblemData,
    basis::Basis;
    x::Union{Nothing,AbstractVector{<:Real}} = nothing,
    rows::Union{AbstractVector{<:Integer},Nothing} = nothing,
    tol = 1e-8,
    estimated_density = 0.10,
)::Tableau
    @timeit "Split data" begin
        nrows, ncols = size(data.constr_lhs)
        lhs_slacks = sparse(I, nrows, nrows)
        lhs_b_vars = data.constr_lhs[:, basis.var_basic]
        lhs_b_slacks = lhs_slacks[:, basis.constr_basic]
        lhs_b = [lhs_b_vars lhs_b_slacks]
        obj_b = [data.obj[basis.var_basic]; zeros(length(basis.constr_basic))]
        if rows === nothing
            rows = 1:nrows
        end
    end

    # A valid basis has exactly one column per constraint row. Report a clear
    # error here instead of letting the factorization fail obscurely.
    nbasic = size(lhs_b, 2)
    nbasic == nrows || throw(
        DimensionMismatch(
            "basis has $nbasic basic columns but the problem has $nrows rows",
        ),
    )

    @timeit "Factorize basis matrix" begin
        factor = klu(sparse(lhs_b'))
    end

    @timeit "Initialize arrays" begin
        num_rows = length(rows)
        tableau_rhs::Array{Float64} = zeros(num_rows)
        tableau_rowptr::Array{Int} = zeros(Int, num_rows + 1)
        tableau_colval::Array{Int} = Int[]
        tableau_nzval::Array{Float64} = Float64[]
        estimated_nnz::Int = round(Int, num_rows * ncols * estimated_density)
        sizehint!(tableau_colval, estimated_nnz)
        sizehint!(tableau_nzval, estimated_nnz)
        e::Array{Float64} = zeros(nrows)
        sol::Array{Float64} = zeros(nrows)
        tableau_row::Array{Float64} = zeros(ncols)
    end

    A = data.constr_lhs'
    b = data.constr_ub
    tableau_rowptr[1] = 1

    @timeit "Process rows" begin
        # `k` is the 1-based output position; `row` is the constraint row index.
        for (k, row) in enumerate(rows)
            @timeit "Solve" begin
                # sol solves B' * sol = e_row, i.e. sol' is row `row` of inv(B).
                fill!(e, 0.0)
                e[row] = 1.0
                ldiv!(sol, factor, e)
            end
            @timeit "Compute row" begin
                mul!(tableau_row, A, sol)
                tableau_rhs[k] = dot(sol, b)
            end
            # Make room for a possibly fully dense row. Growing to at least
            # `needed_space` keeps the hint effective even when the initial
            # estimate was zero or far too small.
            needed_space = length(tableau_colval) + ncols
            if needed_space > estimated_nnz
                @timeit "Grow arrays" begin
                    estimated_nnz = max(2 * estimated_nnz, needed_space)
                    sizehint!(tableau_colval, estimated_nnz)
                    sizehint!(tableau_nzval, estimated_nnz)
                end
            end
            @timeit "Collect nonzeros for row" begin
                for j in 1:ncols
                    val = tableau_row[j]
                    if abs(val) > tol
                        push!(tableau_colval, j)
                        push!(tableau_nzval, val)
                    end
                end
            end
            tableau_rowptr[k+1] = length(tableau_colval) + 1
        end
    end

    @timeit "Shrink arrays" begin
        sizehint!(tableau_colval, length(tableau_colval))
        sizehint!(tableau_nzval, length(tableau_nzval))
    end

    @timeit "Build sparse matrix" begin
        # The rows were collected in CSR layout, which is the CSC layout of the
        # transpose; build that and transpose it back lazily.
        tableau_lhs_transposed = SparseMatrixCSC(
            ncols,
            num_rows,
            tableau_rowptr,
            tableau_colval,
            tableau_nzval,
        )
        tableau_lhs = transpose(tableau_lhs_transposed)
    end

    @timeit "Compute tableau objective row" begin
        y_obj = factor \ obj_b
        tableau_obj = -data.obj' + y_obj' * data.constr_lhs
        tableau_obj[abs.(tableau_obj).<tol] .= 0
        tableau_obj = Array(tableau_obj')
    end

    # Objective value of the provided solution. Use 0.0 (not 0) so `z` is
    # always Float64.
    z = x === nothing ? 0.0 : dot(data.obj, x)

    return Tableau(obj = tableau_obj, lhs = tableau_lhs, rhs = tableau_rhs, z = z)
end
|
|
|
|
export get_basis, get_x, compute_tableau
|