Add types to JuMPSolverData

master
Alinson S. Xavier 4 years ago
parent b5c45966d3
commit c1c0eddb43

@ -9,15 +9,15 @@ using TimerOutputs
mutable struct JuMPSolverData
varname_to_var
optimizer
instance
model
bin_vars
solution
cname_to_constr
reduced_costs
dual_values
optimizer_factory
varname_to_var::Dict{AbstractString,VariableRef}
cname_to_constr::Dict{AbstractString,JuMP.ConstraintRef}
instance::Union{Nothing,PyObject}
model::Union{Nothing,JuMP.Model}
bin_vars::Vector{JuMP.VariableRef}
solution::Vector{Float64}
reduced_costs::Vector{Float64}
dual_values::Dict{JuMP.ConstraintRef,Float64}
end
@ -74,8 +74,8 @@ function _update_solution!(data::JuMPSolverData)
end
end
else
data.reduced_costs = nothing
data.dual_values = nothing
data.reduced_costs = []
data.dual_values = Dict()
end
end
@ -85,7 +85,7 @@ function solve(
tee::Bool=false,
iteration_cb=nothing,
)
instance, model = data.instance, data.model
model = data.model
wallclock_time = 0
log = ""
while true
@ -144,7 +144,7 @@ function solve_lp(data::JuMPSolverData; tee::Bool=false)
end
function set_instance!(data::JuMPSolverData, instance, model)
function set_instance!(data::JuMPSolverData, instance, model::JuMP.Model)
data.instance = instance
data.model = model
data.bin_vars = [
@ -156,9 +156,7 @@ function set_instance!(data::JuMPSolverData, instance, model)
JuMP.name(var) => var
for var in JuMP.all_variables(data.model)
)
if data.optimizer !== nothing
JuMP.set_optimizer(model, data.optimizer)
end
JuMP.set_optimizer(model, data.optimizer_factory)
data.cname_to_constr = Dict()
for (ftype, stype) in JuMP.list_of_constraint_types(model)
for constr in JuMP.all_constraints(model, ftype, stype)
@ -237,8 +235,8 @@ function get_variables(
)
end
rc = data.reduced_costs === nothing ? nothing : Tuple(data.reduced_costs)
values = data.solution === nothing ? nothing : Tuple(data.solution)
rc = isempty(data.reduced_costs) ? nothing : Tuple(data.reduced_costs)
values = isempty(data.solution) ? nothing : Tuple(data.solution)
return miplearn.features.VariableFeatures(
names=names,
@ -260,7 +258,7 @@ function get_constraints(
senses, lhs, rhs = nothing, nothing, nothing
dual_values = nothing
if data.dual_values !== nothing
if !isempty(data.dual_values)
dual_values = []
end
@ -280,7 +278,7 @@ function get_constraints(
length(name) > 0 || continue
push!(names, name)
if data.dual_values !== nothing
if !isempty(data.dual_values)
push!(dual_values, data.dual_values[constr])
end
@ -355,17 +353,17 @@ end
@pydef mutable struct JuMPSolver <: miplearn.solvers.internal.InternalSolver
function __init__(self; optimizer)
function __init__(self, optimizer_factory)
self.data = JuMPSolverData(
nothing, # varname_to_var
optimizer,
optimizer_factory,
Dict(), # varname_to_var
Dict(), # cname_to_constr
nothing, # instance
nothing, # model
nothing, # bin_vars
nothing, # solution
nothing, # cname_to_constr
nothing, # reduced_costs
nothing, # dual_values
[], # bin_vars
[], # solution
[], # reduced_costs
Dict(), # dual_values
)
end
@ -466,4 +464,4 @@ end
end
export JuMPSolver, solve!, fit!, add!
export JuMPSolver

@ -12,7 +12,7 @@ miplearn_tests = pyimport("miplearn.solvers.tests")
traceback = pyimport("traceback")
@testset "JuMPSolver" begin
solver = JuMPSolver(optimizer=Gurobi.Optimizer)
solver = JuMPSolver(Gurobi.Optimizer)
try
miplearn_tests.run_internal_solver_tests(solver)
catch e

Loading…
Cancel
Save