diff --git a/src/bb/varbranch/reliability.jl b/src/bb/varbranch/reliability.jl index 07a262d..e2b92c3 100644 --- a/src/bb/varbranch/reliability.jl +++ b/src/bb/varbranch/reliability.jl @@ -16,6 +16,7 @@ Base.@kwdef mutable struct ReliabilityBranching <: VariableBranchingRule n_sb_calls::Int = 0 side_effect::Bool = true max_iterations::Int = 1_000_000 + aggregation::Symbol = :prod end function find_branching_var( @@ -59,7 +60,8 @@ function find_branching_var( var = var, x = node.fractional_values[σ[i]], side_effect = rule.side_effect, - max_iterations = rule.max_iterations + max_iterations = rule.max_iterations, + aggregation = rule.aggregation, ) else score = pseudocost_scores[σ[i]] diff --git a/src/bb/varbranch/strong.jl b/src/bb/varbranch/strong.jl index 360c2f2..58e19f0 100644 --- a/src/bb/varbranch/strong.jl +++ b/src/bb/varbranch/strong.jl @@ -18,6 +18,7 @@ Base.@kwdef struct StrongBranching <: VariableBranchingRule max_calls::Int = 100 side_effect::Bool = true max_iterations::Int = 1_000_000 + aggregation::Symbol = :prod end function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::Variable @@ -44,6 +45,7 @@ function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool):: x = node.fractional_values[σ[i]], side_effect = rule.side_effect, max_iterations = rule.max_iterations, + aggregation = rule.aggregation, ) # @show name(node.mip, var), round(score[1], digits=2) if score > max_score @@ -66,6 +68,7 @@ function _strong_branch_score(; x::Float64, side_effect::Bool, max_iterations::Int, + aggregation::Symbol, )::Tuple{Float64,Int} # Find current variable lower and upper bounds @@ -92,5 +95,12 @@ function _strong_branch_score(; obj_change_up = obj_change_up, ) end - return (obj_change_up * obj_change_down, var.index) + if aggregation == :prod + return (obj_change_up * obj_change_down, var.index) + elseif aggregation == :min + sense = node.mip.sense + return (min(sense * obj_up, sense * obj_down), var.index) + else + error("Unknown aggregation: $aggregation") + end end diff --git a/test/bb/lp_test.jl b/test/bb/lp_test.jl index 3dad52f..fdb1402 100644 --- a/test/bb/lp_test.jl +++ b/test/bb/lp_test.jl @@ -79,6 +79,8 @@ function runtests(optimizer_name, optimizer; large = true) BB.StrongBranching(), BB.ReliabilityBranching(), BB.HybridBranching(), + BB.StrongBranching(aggregation=:min), + BB.ReliabilityBranching(aggregation=:min), ] for branch_rule in branch_rules for instance in ["bell5", "vpm2"]