BB: Collect strong branching data

master
Alinson S. Xavier 3 years ago
parent 75787090f4
commit 38c4e41720
Signed by: isoron
GPG Key ID: 0DA8E4B9E1109DCA

@ -21,6 +21,9 @@ global UserCutsComponent = PyNULL()
global MemorySample = PyNULL()
global Hdf5Sample = PyNULL()
to_str_array(values) = py"to_str_array"(values)
from_str_array(values) = py"from_str_array"(values)
include("solvers/structs.jl")
include("utils/log.jl")
@ -65,9 +68,6 @@ function __init__()
"""
end
to_str_array(values) = py"to_str_array"(values)
from_str_array(values) = py"from_str_array"(values)
function convert(::Type{SparseMatrixCSC}, o::PyObject)
I, J, V = pyimport("scipy.sparse").find(o)
return sparse(I .+ 1, J .+ 1, V, o.shape...)

@ -10,6 +10,7 @@ frac(x) = x - floor(x)
include("structs.jl")
include("collect.jl")
include("nodepool.jl")
include("optimize.jl")
include("log.jl")

@ -0,0 +1,61 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
using Printf
using Base.Threads
import Base.Threads: @threads, nthreads, threadid
import ..load_data, ..Hdf5Sample
function collect!(
optimizer,
filename::String;
time_limit::Float64 = Inf,
node_limit::Int = typemax(Int),
gap_limit::Float64 = 1e-4,
print_interval::Int = 5,
branch_rule::VariableBranchingRule = ReliabilityBranching(collect = true),
)::NodePool
model = read_from_file(filename)
mip = init(optimizer)
load!(mip, model)
h5 = Hdf5Sample(replace(filename, ".mps.gz" => ".h5"), "r")
primal_bound = h5.get_scalar("mip_upper_bound")
if primal_bound === nothing
primal_bound = h5.get_scalar("mip_obj_value")
end
h5.file.close()
pool = solve!(
mip;
initial_primal_bound = primal_bound,
time_limit,
node_limit,
gap_limit,
print_interval,
branch_rule,
)
h5 = Hdf5Sample(replace(filename, ".mps.gz" => ".h5"))
pseudocost_up = [NaN for _ = 1:mip.nvars]
pseudocost_down = [NaN for _ = 1:mip.nvars]
priorities = [0.0 for _ = 1:mip.nvars]
for (var, var_hist) in pool.var_history
pseudocost_up[var.index] = var_hist.pseudocost_up
pseudocost_down[var.index] = var_hist.pseudocost_down
x = mean(var_hist.fractional_values)
f_up = x - floor(x)
f_down = ceil(x) - x
priorities[var.index] =
var_hist.pseudocost_up * f_up * var_hist.pseudocost_down * f_down
end
h5.put_array("bb_var_pseudocost_up", pseudocost_up)
h5.put_array("bb_var_pseudocost_down", pseudocost_down)
h5.put_array("bb_var_priority", priorities)
collect!(branch_rule, h5)
h5.file.close()
return pool
end

@ -19,15 +19,7 @@ function _probe(
status = CPXlpopt(cpx.env, cpx.lp)
status == 0 || error("CPXlpopt failed ($status)")
status = CPXstrongbranch(
cpx.env,
cpx.lp,
indices,
cnt,
downobj,
upobj,
itlim,
)
status = CPXstrongbranch(cpx.env, cpx.lp, indices, cnt, downobj, upobj, itlim)
status == 0 || error("CPXstrongbranch failed ($status)")
return upobj[1] * mip.sense, downobj[1] * mip.sense

@ -11,13 +11,14 @@ const MOI = MathOptInterface
function init(constructor)::MIP
return MIP(
constructor,
Any[nothing for t = 1:nthreads()],
Variable[],
Float64[],
Float64[],
1.0,
0,
constructor = constructor,
optimizers = Any[nothing for t = 1:nthreads()],
int_vars = Variable[],
int_vars_lb = Float64[],
int_vars_ub = Float64[],
sense = 1.0,
lp_iterations = 0,
nvars = 0,
)
end
@ -27,10 +28,10 @@ function read!(mip::MIP, filename::AbstractString)::Nothing
end
function load!(mip::MIP, prototype::JuMP.Model)
mip.nvars = num_variables(prototype)
_replace_zero_one!(backend(prototype))
_assert_supported(backend(prototype))
mip.int_vars, mip.int_vars_lb, mip.int_vars_ub =
_get_int_variables(backend(prototype))
mip.int_vars, mip.int_vars_lb, mip.int_vars_ub = _get_int_variables(backend(prototype))
mip.sense = _get_objective_sense(backend(prototype))
_relax_integrality!(backend(prototype))
@threads for t = 1:nthreads()
@ -133,11 +134,7 @@ function _get_int_variables(
var_ub = constr.upper
MOI.delete(optimizer, _upper_bound_index(var))
end
MOI.add_constraint(
optimizer,
var,
MOI.Interval(var_lb, var_ub),
)
MOI.add_constraint(optimizer, var, MOI.Interval(var_lb, var_ub))
end
push!(vars, var)
push!(lb, var_lb)

@ -118,7 +118,7 @@ function _create_node(
parent::Union{Nothing,Node} = nothing,
branch_var::Union{Nothing,Variable} = nothing,
branch_var_lb::Union{Nothing,Float64} = nothing,
branch_var_ub::Union{Nothing,Float64}=nothing
branch_var_ub::Union{Nothing,Float64} = nothing,
)::Node
if parent === nothing
branch_vars = Variable[]
@ -135,8 +135,9 @@ function _create_node(
status, obj = solve_relaxation!(mip)
if status == :Optimal
vals = values(mip, mip.int_vars)
fractional_indices =
[j for j in 1:length(mip.int_vars) if 1e-6 < vals[j] - floor(vals[j]) < 1 - 1e-6]
fractional_indices = [
j for j in 1:length(mip.int_vars) if 1e-6 < vals[j] - floor(vals[j]) < 1 - 1e-6
]
fractional_values = vals[fractional_indices]
fractional_variables = mip.int_vars[fractional_indices]
else
@ -159,51 +160,6 @@ function _create_node(
)
end
function solve!(
optimizer,
filename::String;
time_limit::Float64=Inf,
node_limit::Int=typemax(Int),
gap_limit::Float64=1e-4,
print_interval::Int=5,
branch_rule::VariableBranchingRule=ReliabilityBranching()
)::NodePool
model = read_from_file("$filename.mps.gz")
mip = init(optimizer)
load!(mip, model)
h5 = Hdf5Sample("$filename.h5")
primal_bound = h5.get_scalar("mip_obj_value")
nvars = length(h5.get_array("static_var_names"))
pool = solve!(
mip;
initial_primal_bound=primal_bound,
time_limit,
node_limit,
gap_limit,
print_interval,
branch_rule
)
pseudocost_up = [NaN for _ = 1:nvars]
pseudocost_down = [NaN for _ = 1:nvars]
priorities = [0.0 for _ in 1:nvars]
for (var, var_hist) in pool.var_history
pseudocost_up[var.index] = var_hist.pseudocost_up
pseudocost_down[var.index] = var_hist.pseudocost_down
x = mean(var_hist.fractional_values)
f_up = x - floor(x)
f_down = ceil(x) - x
priorities[var.index] = var_hist.pseudocost_up * f_up * var_hist.pseudocost_down * f_down
end
h5.put_array("bb_var_pseudocost_up", pseudocost_up)
h5.put_array("bb_var_pseudocost_down", pseudocost_down)
h5.put_array("bb_var_priority", priorities)
return pool
end
function _set_node_bounds(node::Node)
set_bounds!(node.mip, node.branch_vars, node.branch_lb, node.branch_ub)
end

@ -9,7 +9,7 @@ struct Variable
index::Any
end
mutable struct MIP
Base.@kwdef mutable struct MIP
constructor::Any
optimizers::Vector
int_vars::Vector{Variable}
@ -17,6 +17,7 @@ mutable struct MIP
int_vars_ub::Vector{Float64}
sense::Float64
lp_iterations::Int64
nvars::Int
end
struct Node

@ -2,6 +2,16 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import ..to_str_array
Base.@kwdef mutable struct ReliabilityBranchingStats
branched_count::Vector{Int} = []
num_strong_branch_calls = 0
score_var_names::Vector{String} = []
score_features::Vector{Vector{Float32}} = []
score_targets::Vector{Float32} = []
end
"""
ReliabilityBranching
@ -13,12 +23,14 @@ Base.@kwdef mutable struct ReliabilityBranching <: VariableBranchingRule
min_samples::Int = 8
max_sb_calls::Int = 100
look_ahead::Int = 10
n_sb_calls::Int = 0
side_effect::Bool = true
max_iterations::Int = 1_000_000
aggregation::Symbol = :prod
stats::ReliabilityBranchingStats = ReliabilityBranchingStats()
collect::Bool = false
end
function _strong_branch_score(;
node::Node,
pool::NodePool,
@ -28,7 +40,6 @@ function _strong_branch_score(;
max_iterations::Int,
aggregation::Symbol,
)::Tuple{Float64,Int}
# Find current variable lower and upper bounds
offset = findfirst(isequal(var), node.mip.int_vars)
var_lb = node.mip.int_vars_lb[offset]
@ -68,6 +79,14 @@ function find_branching_var(
node::Node,
pool::NodePool,
)::Variable
stats = rule.stats
# Initialize statistics
if isempty(stats.branched_count)
stats.branched_count = zeros(node.mip.nvars)
end
# Sort variables by pseudocost score
nfrac = length(node.fractional_variables)
pseudocost_scores = [
_pseudocost_score(
@ -79,10 +98,37 @@ function find_branching_var(
]
σ = sortperm(pseudocost_scores, rev = true)
sorted_vars = node.fractional_variables[σ]
if rule.collect
# Compute dynamic features for all fractional variables
features = []
for (i, var) in enumerate(sorted_vars)
branched_count = stats.branched_count[var.index]
branched_count_rel = 0.0
branched_count_sum = sum(stats.branched_count[var.index])
if branched_count_sum > 0
branched_count_rel = branched_count / branched_count_sum
end
push!(
features,
Float32[
nfrac,
node.fractional_values[σ[i]],
node.depth,
pseudocost_scores[σ[i]][1],
branched_count,
branched_count_rel,
],
)
end
end
_set_node_bounds(node)
no_improv_count, n_sb_calls = 0, 0
max_score, max_var = pseudocost_scores[σ[1]], sorted_vars[1]
max_score, max_var = (-Inf, -Inf), sorted_vars[1]
for (i, var) in enumerate(sorted_vars)
# Decide whether to use strong branching
use_strong_branch = true
if n_sb_calls >= rule.max_sb_calls
use_strong_branch = false
@ -95,9 +141,10 @@ function find_branching_var(
end
end
end
if use_strong_branch
# Compute strong branching score
n_sb_calls += 1
rule.n_sb_calls += 1
score = _strong_branch_score(
node = node,
pool = pool,
@ -107,6 +154,13 @@ function find_branching_var(
max_iterations = rule.max_iterations,
aggregation = rule.aggregation,
)
if rule.collect
# Store training data
push!(stats.score_var_names, name(node.mip, var))
push!(stats.score_features, features[i])
push!(stats.score_targets, score[1])
end
else
score = pseudocost_scores[σ[i]]
end
@ -119,5 +173,16 @@ function find_branching_var(
no_improv_count <= rule.look_ahead || break
end
_unset_node_bounds(node)
# Update statistics
stats.branched_count[max_var.index] += 1
stats.num_strong_branch_calls += n_sb_calls
return max_var
end
function collect!(rule::ReliabilityBranching, h5)
h5.put_array("bb_score_var_names", to_str_array(rule.stats.score_var_names))
h5.put_array("bb_score_features", vcat(rule.stats.score_features'...))
h5.put_array("bb_score_targets", rule.stats.score_targets)
end

@ -80,11 +80,7 @@ function _update_solution!(data::JuMPSolverData)
push!(data.reduced_costs, rc)
# Basis status
data.basis_status[var] = MOI.get(
data.model,
MOI.VariableBasisStatus(),
var,
)
data.basis_status[var] = MOI.get(data.model, MOI.VariableBasisStatus(), var)
end
try

@ -70,7 +70,8 @@ function runtests(optimizer_name, optimizer; large = true)
end
@testset "varbranch" begin
branch_rules = [
for instance in ["bell5", "vpm2"]
for branch_rule in [
BB.RandomBranching(),
BB.FirstInfeasibleBranching(),
BB.LeastInfeasibleBranching(),
@ -80,15 +81,14 @@ function runtests(optimizer_name, optimizer; large = true)
BB.ReliabilityBranching(),
BB.HybridBranching(),
BB.StrongBranching(aggregation = :min),
BB.ReliabilityBranching(aggregation=:min),
BB.ReliabilityBranching(aggregation = :min, collect = true),
]
for branch_rule in branch_rules
for instance in ["bell5", "vpm2"]
h5 = Hdf5Sample("$basepath/../fixtures/$instance.h5")
mip_lower_bound = h5.get_scalar("mip_lower_bound")
mip_upper_bound = h5.get_scalar("mip_upper_bound")
mip_sense = h5.get_scalar("mip_sense")
mip_primal_bound = mip_sense == "min" ? mip_upper_bound : mip_lower_bound
mip_primal_bound =
mip_sense == "min" ? mip_upper_bound : mip_lower_bound
h5.file.close()
mip = BB.init(optimizer)
@ -104,25 +104,35 @@ function runtests(optimizer_name, optimizer; large = true)
end
end
end
@testset "collect" begin
rule = BB.ReliabilityBranching(collect = true)
BB.collect!(
optimizer,
"$basepath/../fixtures/bell5.mps.gz",
node_limit = 100,
print_interval = 10,
branch_rule = rule,
)
n_sb = rule.stats.num_strong_branch_calls
h5 = Hdf5Sample("$basepath/../fixtures/bell5.h5")
@test size(h5.get_array("bb_var_pseudocost_up")) == (104,)
@test size(h5.get_array("bb_score_var_names")) == (n_sb,)
@test size(h5.get_array("bb_score_features")) == (n_sb, 6)
@test size(h5.get_array("bb_score_targets")) == (n_sb,)
h5.file.close()
end
end
end
@testset "BB" begin
@time runtests(
"Clp",
optimizer_with_attributes(
Clp.Optimizer,
),
)
@time runtests("Clp", optimizer_with_attributes(Clp.Optimizer))
if is_gurobi_available
using Gurobi
@time runtests(
"Gurobi",
optimizer_with_attributes(
Gurobi.Optimizer,
"Threads" => 1,
)
optimizer_with_attributes(Gurobi.Optimizer, "Threads" => 1),
)
end
@ -130,10 +140,7 @@ end
using CPLEX
@time runtests(
"CPLEX",
optimizer_with_attributes(
CPLEX.Optimizer,
"CPXPARAM_Threads" => 1,
),
optimizer_with_attributes(CPLEX.Optimizer, "CPXPARAM_Threads" => 1),
)
end
end

Binary file not shown.
Loading…
Cancel
Save