diff --git a/src/modeling/jump_solver.jl b/src/modeling/jump_solver.jl index bd642a9..9a3113a 100644 --- a/src/modeling/jump_solver.jl +++ b/src/modeling/jump_solver.jl @@ -9,15 +9,15 @@ using TimerOutputs mutable struct JuMPSolverData - varname_to_var - optimizer - instance - model - bin_vars - solution - cname_to_constr - reduced_costs - dual_values + optimizer_factory + varname_to_var::Dict{AbstractString,VariableRef} + cname_to_constr::Dict{AbstractString,JuMP.ConstraintRef} + instance::Union{Nothing,PyObject} + model::Union{Nothing,JuMP.Model} + bin_vars::Vector{JuMP.VariableRef} + solution::Vector{Float64} + reduced_costs::Vector{Float64} + dual_values::Dict{JuMP.ConstraintRef,Float64} end @@ -74,8 +74,8 @@ function _update_solution!(data::JuMPSolverData) end end else - data.reduced_costs = nothing - data.dual_values = nothing + data.reduced_costs = [] + data.dual_values = Dict() end end @@ -85,7 +85,7 @@ function solve( tee::Bool=false, iteration_cb=nothing, ) - instance, model = data.instance, data.model + model = data.model wallclock_time = 0 log = "" while true @@ -144,7 +144,7 @@ function solve_lp(data::JuMPSolverData; tee::Bool=false) end -function set_instance!(data::JuMPSolverData, instance, model) +function set_instance!(data::JuMPSolverData, instance, model::JuMP.Model) data.instance = instance data.model = model data.bin_vars = [ @@ -156,9 +156,7 @@ function set_instance!(data::JuMPSolverData, instance, model) JuMP.name(var) => var for var in JuMP.all_variables(data.model) ) - if data.optimizer !== nothing - JuMP.set_optimizer(model, data.optimizer) - end + JuMP.set_optimizer(model, data.optimizer_factory) data.cname_to_constr = Dict() for (ftype, stype) in JuMP.list_of_constraint_types(model) for constr in JuMP.all_constraints(model, ftype, stype) @@ -237,8 +235,8 @@ function get_variables( ) end - rc = data.reduced_costs === nothing ? nothing : Tuple(data.reduced_costs) - values = data.solution === nothing ? nothing : Tuple(data.solution) + rc = isempty(data.reduced_costs) ? nothing : Tuple(data.reduced_costs) + values = isempty(data.solution) ? nothing : Tuple(data.solution) return miplearn.features.VariableFeatures( names=names, @@ -260,7 +258,7 @@ function get_constraints( senses, lhs, rhs = nothing, nothing, nothing dual_values = nothing - if data.dual_values !== nothing + if !isempty(data.dual_values) dual_values = [] end @@ -280,7 +278,7 @@ function get_constraints( length(name) > 0 || continue push!(names, name) - if data.dual_values !== nothing + if !isempty(data.dual_values) push!(dual_values, data.dual_values[constr]) end @@ -355,18 +353,18 @@ end @pydef mutable struct JuMPSolver <: miplearn.solvers.internal.InternalSolver - function __init__(self; optimizer) + function __init__(self, optimizer_factory) self.data = JuMPSolverData( - nothing, # varname_to_var - optimizer, + optimizer_factory, + Dict(), # varname_to_var + Dict(), # cname_to_constr nothing, # instance nothing, # model - nothing, # bin_vars - nothing, # solution - nothing, # cname_to_constr - nothing, # reduced_costs - nothing, # dual_values - ) + [], # bin_vars + [], # solution + [], # reduced_costs + Dict(), # dual_values + ) end add_constraints(self, cf) = @@ -466,4 +464,4 @@ end end -export JuMPSolver, solve!, fit!, add! +export JuMPSolver diff --git a/test/modeling/jump_solver_test.jl b/test/modeling/jump_solver_test.jl index 96a8ddc..b9470c8 100644 --- a/test/modeling/jump_solver_test.jl +++ b/test/modeling/jump_solver_test.jl @@ -12,7 +12,7 @@ miplearn_tests = pyimport("miplearn.solvers.tests") traceback = pyimport("traceback") @testset "JuMPSolver" begin - solver = JuMPSolver(optimizer=Gurobi.Optimizer) + solver = JuMPSolver(Gurobi.Optimizer) try miplearn_tests.run_internal_solver_tests(solver) catch e