mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 08:28:52 -06:00
Add types to JuMPSolverData
This commit is contained in:
@@ -9,15 +9,15 @@ using TimerOutputs
|
||||
|
||||
|
||||
mutable struct JuMPSolverData
|
||||
varname_to_var
|
||||
optimizer
|
||||
instance
|
||||
model
|
||||
bin_vars
|
||||
solution
|
||||
cname_to_constr
|
||||
reduced_costs
|
||||
dual_values
|
||||
optimizer_factory
|
||||
varname_to_var::Dict{AbstractString,VariableRef}
|
||||
cname_to_constr::Dict{AbstractString,JuMP.ConstraintRef}
|
||||
instance::Union{Nothing,PyObject}
|
||||
model::Union{Nothing,JuMP.Model}
|
||||
bin_vars::Vector{JuMP.VariableRef}
|
||||
solution::Vector{Float64}
|
||||
reduced_costs::Vector{Float64}
|
||||
dual_values::Dict{JuMP.ConstraintRef,Float64}
|
||||
end
|
||||
|
||||
|
||||
@@ -74,8 +74,8 @@ function _update_solution!(data::JuMPSolverData)
|
||||
end
|
||||
end
|
||||
else
|
||||
data.reduced_costs = nothing
|
||||
data.dual_values = nothing
|
||||
data.reduced_costs = []
|
||||
data.dual_values = Dict()
|
||||
end
|
||||
end
|
||||
|
||||
@@ -85,7 +85,7 @@ function solve(
|
||||
tee::Bool=false,
|
||||
iteration_cb=nothing,
|
||||
)
|
||||
instance, model = data.instance, data.model
|
||||
model = data.model
|
||||
wallclock_time = 0
|
||||
log = ""
|
||||
while true
|
||||
@@ -144,7 +144,7 @@ function solve_lp(data::JuMPSolverData; tee::Bool=false)
|
||||
end
|
||||
|
||||
|
||||
function set_instance!(data::JuMPSolverData, instance, model)
|
||||
function set_instance!(data::JuMPSolverData, instance, model::JuMP.Model)
|
||||
data.instance = instance
|
||||
data.model = model
|
||||
data.bin_vars = [
|
||||
@@ -156,9 +156,7 @@ function set_instance!(data::JuMPSolverData, instance, model)
|
||||
JuMP.name(var) => var
|
||||
for var in JuMP.all_variables(data.model)
|
||||
)
|
||||
if data.optimizer !== nothing
|
||||
JuMP.set_optimizer(model, data.optimizer)
|
||||
end
|
||||
JuMP.set_optimizer(model, data.optimizer_factory)
|
||||
data.cname_to_constr = Dict()
|
||||
for (ftype, stype) in JuMP.list_of_constraint_types(model)
|
||||
for constr in JuMP.all_constraints(model, ftype, stype)
|
||||
@@ -237,8 +235,8 @@ function get_variables(
|
||||
)
|
||||
end
|
||||
|
||||
rc = data.reduced_costs === nothing ? nothing : Tuple(data.reduced_costs)
|
||||
values = data.solution === nothing ? nothing : Tuple(data.solution)
|
||||
rc = isempty(data.reduced_costs) ? nothing : Tuple(data.reduced_costs)
|
||||
values = isempty(data.solution) ? nothing : Tuple(data.solution)
|
||||
|
||||
return miplearn.features.VariableFeatures(
|
||||
names=names,
|
||||
@@ -260,7 +258,7 @@ function get_constraints(
|
||||
senses, lhs, rhs = nothing, nothing, nothing
|
||||
dual_values = nothing
|
||||
|
||||
if data.dual_values !== nothing
|
||||
if !isempty(data.dual_values)
|
||||
dual_values = []
|
||||
end
|
||||
|
||||
@@ -280,7 +278,7 @@ function get_constraints(
|
||||
length(name) > 0 || continue
|
||||
push!(names, name)
|
||||
|
||||
if data.dual_values !== nothing
|
||||
if !isempty(data.dual_values)
|
||||
push!(dual_values, data.dual_values[constr])
|
||||
end
|
||||
|
||||
@@ -355,17 +353,17 @@ end
|
||||
|
||||
|
||||
@pydef mutable struct JuMPSolver <: miplearn.solvers.internal.InternalSolver
|
||||
function __init__(self; optimizer)
|
||||
function __init__(self, optimizer_factory)
|
||||
self.data = JuMPSolverData(
|
||||
nothing, # varname_to_var
|
||||
optimizer,
|
||||
optimizer_factory,
|
||||
Dict(), # varname_to_var
|
||||
Dict(), # cname_to_constr
|
||||
nothing, # instance
|
||||
nothing, # model
|
||||
nothing, # bin_vars
|
||||
nothing, # solution
|
||||
nothing, # cname_to_constr
|
||||
nothing, # reduced_costs
|
||||
nothing, # dual_values
|
||||
[], # bin_vars
|
||||
[], # solution
|
||||
[], # reduced_costs
|
||||
Dict(), # dual_values
|
||||
)
|
||||
end
|
||||
|
||||
@@ -466,4 +464,4 @@ end
|
||||
end
|
||||
|
||||
|
||||
export JuMPSolver, solve!, fit!, add!
|
||||
export JuMPSolver
|
||||
|
||||
@@ -12,7 +12,7 @@ miplearn_tests = pyimport("miplearn.solvers.tests")
|
||||
traceback = pyimport("traceback")
|
||||
|
||||
@testset "JuMPSolver" begin
|
||||
solver = JuMPSolver(optimizer=Gurobi.Optimizer)
|
||||
solver = JuMPSolver(Gurobi.Optimizer)
|
||||
try
|
||||
miplearn_tests.run_internal_solver_tests(solver)
|
||||
catch e
|
||||
|
||||
Reference in New Issue
Block a user