mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 08:28:52 -06:00
BB: Make aggregation configurable
This commit is contained in:
@@ -16,6 +16,7 @@ Base.@kwdef mutable struct ReliabilityBranching <: VariableBranchingRule
|
|||||||
n_sb_calls::Int = 0
|
n_sb_calls::Int = 0
|
||||||
side_effect::Bool = true
|
side_effect::Bool = true
|
||||||
max_iterations::Int = 1_000_000
|
max_iterations::Int = 1_000_000
|
||||||
|
aggregation::Symbol = :prod
|
||||||
end
|
end
|
||||||
|
|
||||||
function find_branching_var(
|
function find_branching_var(
|
||||||
@@ -59,7 +60,8 @@ function find_branching_var(
|
|||||||
var = var,
|
var = var,
|
||||||
x = node.fractional_values[σ[i]],
|
x = node.fractional_values[σ[i]],
|
||||||
side_effect = rule.side_effect,
|
side_effect = rule.side_effect,
|
||||||
max_iterations = rule.max_iterations
|
max_iterations = rule.max_iterations,
|
||||||
|
aggregation = rule.aggregation,
|
||||||
)
|
)
|
||||||
else
|
else
|
||||||
score = pseudocost_scores[σ[i]]
|
score = pseudocost_scores[σ[i]]
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ Base.@kwdef struct StrongBranching <: VariableBranchingRule
|
|||||||
max_calls::Int = 100
|
max_calls::Int = 100
|
||||||
side_effect::Bool = true
|
side_effect::Bool = true
|
||||||
max_iterations::Int = 1_000_000
|
max_iterations::Int = 1_000_000
|
||||||
|
aggregation::Symbol = :prod
|
||||||
end
|
end
|
||||||
|
|
||||||
function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::Variable
|
function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::Variable
|
||||||
@@ -44,6 +45,7 @@ function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::
|
|||||||
x = node.fractional_values[σ[i]],
|
x = node.fractional_values[σ[i]],
|
||||||
side_effect = rule.side_effect,
|
side_effect = rule.side_effect,
|
||||||
max_iterations = rule.max_iterations,
|
max_iterations = rule.max_iterations,
|
||||||
|
aggregation = rule.aggregation,
|
||||||
)
|
)
|
||||||
# @show name(node.mip, var), round(score[1], digits=2)
|
# @show name(node.mip, var), round(score[1], digits=2)
|
||||||
if score > max_score
|
if score > max_score
|
||||||
@@ -66,6 +68,7 @@ function _strong_branch_score(;
|
|||||||
x::Float64,
|
x::Float64,
|
||||||
side_effect::Bool,
|
side_effect::Bool,
|
||||||
max_iterations::Int,
|
max_iterations::Int,
|
||||||
|
aggregation::Symbol,
|
||||||
)::Tuple{Float64,Int}
|
)::Tuple{Float64,Int}
|
||||||
|
|
||||||
# Find current variable lower and upper bounds
|
# Find current variable lower and upper bounds
|
||||||
@@ -92,5 +95,12 @@ function _strong_branch_score(;
|
|||||||
obj_change_up = obj_change_up,
|
obj_change_up = obj_change_up,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
if aggregation == :prod
|
||||||
return (obj_change_up * obj_change_down, var.index)
|
return (obj_change_up * obj_change_down, var.index)
|
||||||
|
elseif aggregation == :min
|
||||||
|
sense = node.mip.sense
|
||||||
|
return (min(sense * obj_up, sense * obj_down), var.index)
|
||||||
|
else
|
||||||
|
error("Unknown aggregation: $aggregation")
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ function runtests(optimizer_name, optimizer; large = true)
|
|||||||
BB.StrongBranching(),
|
BB.StrongBranching(),
|
||||||
BB.ReliabilityBranching(),
|
BB.ReliabilityBranching(),
|
||||||
BB.HybridBranching(),
|
BB.HybridBranching(),
|
||||||
|
BB.StrongBranching(aggregation=:min),
|
||||||
|
BB.ReliabilityBranching(aggregation=:min),
|
||||||
]
|
]
|
||||||
for branch_rule in branch_rules
|
for branch_rule in branch_rules
|
||||||
for instance in ["bell5", "vpm2"]
|
for instance in ["bell5", "vpm2"]
|
||||||
|
|||||||
Reference in New Issue
Block a user