mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 08:28:52 -06:00
BB: Remove duplication in SB/RB
This commit is contained in:
@@ -19,6 +19,50 @@ Base.@kwdef mutable struct ReliabilityBranching <: VariableBranchingRule
|
|||||||
aggregation::Symbol = :prod
|
aggregation::Symbol = :prod
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function _strong_branch_score(;
|
||||||
|
node::Node,
|
||||||
|
pool::NodePool,
|
||||||
|
var::Variable,
|
||||||
|
x::Float64,
|
||||||
|
side_effect::Bool,
|
||||||
|
max_iterations::Int,
|
||||||
|
aggregation::Symbol,
|
||||||
|
)::Tuple{Float64,Int}
|
||||||
|
|
||||||
|
# Find current variable lower and upper bounds
|
||||||
|
offset = findfirst(isequal(var), node.mip.int_vars)
|
||||||
|
var_lb = node.mip.int_vars_lb[offset]
|
||||||
|
var_ub = node.mip.int_vars_ub[offset]
|
||||||
|
for (offset, v) in enumerate(node.branch_vars)
|
||||||
|
if v == var
|
||||||
|
var_lb = max(var_lb, node.branch_lb[offset])
|
||||||
|
var_ub = min(var_ub, node.branch_ub[offset])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
obj_up, obj_down = 0, 0
|
||||||
|
obj_up, obj_down = probe(node.mip, var, x, var_lb, var_ub, max_iterations)
|
||||||
|
obj_change_up = obj_up - node.obj
|
||||||
|
obj_change_down = obj_down - node.obj
|
||||||
|
if side_effect
|
||||||
|
_update_var_history(
|
||||||
|
pool = pool,
|
||||||
|
var = var,
|
||||||
|
x = x,
|
||||||
|
obj_change_down = obj_change_down,
|
||||||
|
obj_change_up = obj_change_up,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
if aggregation == :prod
|
||||||
|
return (obj_change_up * obj_change_down, var.index)
|
||||||
|
elseif aggregation == :min
|
||||||
|
sense = node.mip.sense
|
||||||
|
return (min(sense * obj_up, sense * obj_down), var.index)
|
||||||
|
else
|
||||||
|
error("Unknown aggregation: $aggregation")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function find_branching_var(
|
function find_branching_var(
|
||||||
rule::ReliabilityBranching,
|
rule::ReliabilityBranching,
|
||||||
node::Node,
|
node::Node,
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ using Random
|
|||||||
Branching strategy that selects a subset of fractional variables
|
Branching strategy that selects a subset of fractional variables
|
||||||
as candidates (according to pseudocosts) the solves two linear
|
as candidates (according to pseudocosts) the solves two linear
|
||||||
programming problems for each candidate.
|
programming problems for each candidate.
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Base.@kwdef struct StrongBranching <: VariableBranchingRule
|
Base.@kwdef struct StrongBranching <: VariableBranchingRule
|
||||||
look_ahead::Int = 10
|
look_ahead::Int = 10
|
||||||
@@ -22,85 +20,13 @@ Base.@kwdef struct StrongBranching <: VariableBranchingRule
|
|||||||
end
|
end
|
||||||
|
|
||||||
function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::Variable
|
function find_branching_var(rule::StrongBranching, node::Node, pool::NodePool)::Variable
|
||||||
nfrac = length(node.fractional_variables)
|
rb_rule = ReliabilityBranching(
|
||||||
pseudocost_scores = [
|
min_samples=typemax(Int),
|
||||||
_pseudocost_score(
|
max_sb_calls=rule.max_calls,
|
||||||
node,
|
look_ahead=rule.look_ahead,
|
||||||
pool,
|
side_effect=rule.side_effect,
|
||||||
node.fractional_variables[j],
|
max_iterations=rule.max_iterations,
|
||||||
node.fractional_values[j],
|
aggregation=rule.aggregation,
|
||||||
) for j = 1:nfrac
|
)
|
||||||
]
|
return find_branching_var(rb_rule, node, pool)
|
||||||
σ = sortperm(pseudocost_scores, rev = true)
|
|
||||||
sorted_vars = node.fractional_variables[σ]
|
|
||||||
_set_node_bounds(node)
|
|
||||||
no_improv_count, call_count = 0, 0
|
|
||||||
max_score, max_var = (-Inf, -Inf), sorted_vars[1]
|
|
||||||
for (i, var) in enumerate(sorted_vars)
|
|
||||||
call_count += 1
|
|
||||||
score = _strong_branch_score(
|
|
||||||
node = node,
|
|
||||||
pool = pool,
|
|
||||||
var = var,
|
|
||||||
x = node.fractional_values[σ[i]],
|
|
||||||
side_effect = rule.side_effect,
|
|
||||||
max_iterations = rule.max_iterations,
|
|
||||||
aggregation = rule.aggregation,
|
|
||||||
)
|
|
||||||
# @show name(node.mip, var), round(score[1], digits=2)
|
|
||||||
if score > max_score
|
|
||||||
max_score, max_var = score, var
|
|
||||||
no_improv_count = 0
|
|
||||||
else
|
|
||||||
no_improv_count += 1
|
|
||||||
end
|
|
||||||
no_improv_count <= rule.look_ahead || break
|
|
||||||
call_count <= rule.max_calls || break
|
|
||||||
end
|
|
||||||
_unset_node_bounds(node)
|
|
||||||
return max_var
|
|
||||||
end
|
|
||||||
|
|
||||||
function _strong_branch_score(;
|
|
||||||
node::Node,
|
|
||||||
pool::NodePool,
|
|
||||||
var::Variable,
|
|
||||||
x::Float64,
|
|
||||||
side_effect::Bool,
|
|
||||||
max_iterations::Int,
|
|
||||||
aggregation::Symbol,
|
|
||||||
)::Tuple{Float64,Int}
|
|
||||||
|
|
||||||
# Find current variable lower and upper bounds
|
|
||||||
offset = findfirst(isequal(var), node.mip.int_vars)
|
|
||||||
var_lb = node.mip.int_vars_lb[offset]
|
|
||||||
var_ub = node.mip.int_vars_ub[offset]
|
|
||||||
for (offset, v) in enumerate(node.branch_vars)
|
|
||||||
if v == var
|
|
||||||
var_lb = max(var_lb, node.branch_lb[offset])
|
|
||||||
var_ub = min(var_ub, node.branch_ub[offset])
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
obj_up, obj_down = 0, 0
|
|
||||||
obj_up, obj_down = probe(node.mip, var, x, var_lb, var_ub, max_iterations)
|
|
||||||
obj_change_up = obj_up - node.obj
|
|
||||||
obj_change_down = obj_down - node.obj
|
|
||||||
if side_effect
|
|
||||||
_update_var_history(
|
|
||||||
pool = pool,
|
|
||||||
var = var,
|
|
||||||
x = x,
|
|
||||||
obj_change_down = obj_change_down,
|
|
||||||
obj_change_up = obj_change_up,
|
|
||||||
)
|
|
||||||
end
|
|
||||||
if aggregation == :prod
|
|
||||||
return (obj_change_up * obj_change_down, var.index)
|
|
||||||
elseif aggregation == :min
|
|
||||||
sense = node.mip.sense
|
|
||||||
return (min(sense * obj_up, sense * obj_down), var.index)
|
|
||||||
else
|
|
||||||
error("Unknown aggregation: $aggregation")
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|||||||
Reference in New Issue
Block a user