mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 00:18:51 -06:00
BB: Make aggregation configurable
This commit is contained in:
@@ -16,6 +16,7 @@ Base.@kwdef mutable struct ReliabilityBranching <: VariableBranchingRule
|
||||
n_sb_calls::Int = 0
|
||||
side_effect::Bool = true
|
||||
max_iterations::Int = 1_000_000
|
||||
aggregation::Symbol = :prod
|
||||
end
|
||||
|
||||
function find_branching_var(
|
||||
@@ -59,7 +60,8 @@ function find_branching_var(
|
||||
var = var,
|
||||
x = node.fractional_values[σ[i]],
|
||||
side_effect = rule.side_effect,
|
||||
max_iterations = rule.max_iterations
|
||||
max_iterations = rule.max_iterations,
|
||||
aggregation = rule.aggregation,
|
||||
)
|
||||
else
|
||||
score = pseudocost_scores[σ[i]]
|
||||
|
||||
@@ -18,6 +18,7 @@ Base.@kwdef struct StrongBranching <: VariableBranchingRule
|
||||
max_calls::Int = 100
|
||||
side_effect::Bool = true
|
||||
max_iterations::Int = 1_000_000
|
||||
aggregation::Symbol = :prod
|
||||
end
|
||||
|
||||
function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::Variable
|
||||
@@ -44,6 +45,7 @@ function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::
|
||||
x = node.fractional_values[σ[i]],
|
||||
side_effect = rule.side_effect,
|
||||
max_iterations = rule.max_iterations,
|
||||
aggregation = rule.aggregation,
|
||||
)
|
||||
# @show name(node.mip, var), round(score[1], digits=2)
|
||||
if score > max_score
|
||||
@@ -66,6 +68,7 @@ function _strong_branch_score(;
|
||||
x::Float64,
|
||||
side_effect::Bool,
|
||||
max_iterations::Int,
|
||||
aggregation::Symbol,
|
||||
)::Tuple{Float64,Int}
|
||||
|
||||
# Find current variable lower and upper bounds
|
||||
@@ -92,5 +95,12 @@ function _strong_branch_score(;
|
||||
obj_change_up = obj_change_up,
|
||||
)
|
||||
end
|
||||
return (obj_change_up * obj_change_down, var.index)
|
||||
if aggregation == :prod
|
||||
return (obj_change_up * obj_change_down, var.index)
|
||||
elseif aggregation == :min
|
||||
sense = node.mip.sense
|
||||
return (min(sense * obj_up, sense * obj_down), var.index)
|
||||
else
|
||||
error("Unknown aggregation: $aggregation")
|
||||
end
|
||||
end
|
||||
|
||||
@@ -79,6 +79,8 @@ function runtests(optimizer_name, optimizer; large = true)
|
||||
BB.StrongBranching(),
|
||||
BB.ReliabilityBranching(),
|
||||
BB.HybridBranching(),
|
||||
BB.StrongBranching(aggregation=:min),
|
||||
BB.ReliabilityBranching(aggregation=:min),
|
||||
]
|
||||
for branch_rule in branch_rules
|
||||
for instance in ["bell5", "vpm2"]
|
||||
|
||||
Reference in New Issue
Block a user