mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 00:18:51 -06:00
Add types to JuMPSolverData
This commit is contained in:
@@ -9,15 +9,15 @@ using TimerOutputs
|
|||||||
|
|
||||||
|
|
||||||
mutable struct JuMPSolverData
|
mutable struct JuMPSolverData
|
||||||
varname_to_var
|
optimizer_factory
|
||||||
optimizer
|
varname_to_var::Dict{AbstractString,VariableRef}
|
||||||
instance
|
cname_to_constr::Dict{AbstractString,JuMP.ConstraintRef}
|
||||||
model
|
instance::Union{Nothing,PyObject}
|
||||||
bin_vars
|
model::Union{Nothing,JuMP.Model}
|
||||||
solution
|
bin_vars::Vector{JuMP.VariableRef}
|
||||||
cname_to_constr
|
solution::Vector{Float64}
|
||||||
reduced_costs
|
reduced_costs::Vector{Float64}
|
||||||
dual_values
|
dual_values::Dict{JuMP.ConstraintRef,Float64}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
@@ -74,8 +74,8 @@ function _update_solution!(data::JuMPSolverData)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
data.reduced_costs = nothing
|
data.reduced_costs = []
|
||||||
data.dual_values = nothing
|
data.dual_values = Dict()
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ function solve(
|
|||||||
tee::Bool=false,
|
tee::Bool=false,
|
||||||
iteration_cb=nothing,
|
iteration_cb=nothing,
|
||||||
)
|
)
|
||||||
instance, model = data.instance, data.model
|
model = data.model
|
||||||
wallclock_time = 0
|
wallclock_time = 0
|
||||||
log = ""
|
log = ""
|
||||||
while true
|
while true
|
||||||
@@ -144,7 +144,7 @@ function solve_lp(data::JuMPSolverData; tee::Bool=false)
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
function set_instance!(data::JuMPSolverData, instance, model)
|
function set_instance!(data::JuMPSolverData, instance, model::JuMP.Model)
|
||||||
data.instance = instance
|
data.instance = instance
|
||||||
data.model = model
|
data.model = model
|
||||||
data.bin_vars = [
|
data.bin_vars = [
|
||||||
@@ -156,9 +156,7 @@ function set_instance!(data::JuMPSolverData, instance, model)
|
|||||||
JuMP.name(var) => var
|
JuMP.name(var) => var
|
||||||
for var in JuMP.all_variables(data.model)
|
for var in JuMP.all_variables(data.model)
|
||||||
)
|
)
|
||||||
if data.optimizer !== nothing
|
JuMP.set_optimizer(model, data.optimizer_factory)
|
||||||
JuMP.set_optimizer(model, data.optimizer)
|
|
||||||
end
|
|
||||||
data.cname_to_constr = Dict()
|
data.cname_to_constr = Dict()
|
||||||
for (ftype, stype) in JuMP.list_of_constraint_types(model)
|
for (ftype, stype) in JuMP.list_of_constraint_types(model)
|
||||||
for constr in JuMP.all_constraints(model, ftype, stype)
|
for constr in JuMP.all_constraints(model, ftype, stype)
|
||||||
@@ -237,8 +235,8 @@ function get_variables(
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
rc = data.reduced_costs === nothing ? nothing : Tuple(data.reduced_costs)
|
rc = isempty(data.reduced_costs) ? nothing : Tuple(data.reduced_costs)
|
||||||
values = data.solution === nothing ? nothing : Tuple(data.solution)
|
values = isempty(data.solution) ? nothing : Tuple(data.solution)
|
||||||
|
|
||||||
return miplearn.features.VariableFeatures(
|
return miplearn.features.VariableFeatures(
|
||||||
names=names,
|
names=names,
|
||||||
@@ -260,7 +258,7 @@ function get_constraints(
|
|||||||
senses, lhs, rhs = nothing, nothing, nothing
|
senses, lhs, rhs = nothing, nothing, nothing
|
||||||
dual_values = nothing
|
dual_values = nothing
|
||||||
|
|
||||||
if data.dual_values !== nothing
|
if !isempty(data.dual_values)
|
||||||
dual_values = []
|
dual_values = []
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -280,7 +278,7 @@ function get_constraints(
|
|||||||
length(name) > 0 || continue
|
length(name) > 0 || continue
|
||||||
push!(names, name)
|
push!(names, name)
|
||||||
|
|
||||||
if data.dual_values !== nothing
|
if !isempty(data.dual_values)
|
||||||
push!(dual_values, data.dual_values[constr])
|
push!(dual_values, data.dual_values[constr])
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -355,18 +353,18 @@ end
|
|||||||
|
|
||||||
|
|
||||||
@pydef mutable struct JuMPSolver <: miplearn.solvers.internal.InternalSolver
|
@pydef mutable struct JuMPSolver <: miplearn.solvers.internal.InternalSolver
|
||||||
function __init__(self; optimizer)
|
function __init__(self, optimizer_factory)
|
||||||
self.data = JuMPSolverData(
|
self.data = JuMPSolverData(
|
||||||
nothing, # varname_to_var
|
optimizer_factory,
|
||||||
optimizer,
|
Dict(), # varname_to_var
|
||||||
|
Dict(), # cname_to_constr
|
||||||
nothing, # instance
|
nothing, # instance
|
||||||
nothing, # model
|
nothing, # model
|
||||||
nothing, # bin_vars
|
[], # bin_vars
|
||||||
nothing, # solution
|
[], # solution
|
||||||
nothing, # cname_to_constr
|
[], # reduced_costs
|
||||||
nothing, # reduced_costs
|
Dict(), # dual_values
|
||||||
nothing, # dual_values
|
)
|
||||||
)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
add_constraints(self, cf) =
|
add_constraints(self, cf) =
|
||||||
@@ -466,4 +464,4 @@ end
|
|||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
export JuMPSolver, solve!, fit!, add!
|
export JuMPSolver
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ miplearn_tests = pyimport("miplearn.solvers.tests")
|
|||||||
traceback = pyimport("traceback")
|
traceback = pyimport("traceback")
|
||||||
|
|
||||||
@testset "JuMPSolver" begin
|
@testset "JuMPSolver" begin
|
||||||
solver = JuMPSolver(optimizer=Gurobi.Optimizer)
|
solver = JuMPSolver(Gurobi.Optimizer)
|
||||||
try
|
try
|
||||||
miplearn_tests.run_internal_solver_tests(solver)
|
miplearn_tests.run_internal_solver_tests(solver)
|
||||||
catch e
|
catch e
|
||||||
|
|||||||
Reference in New Issue
Block a user