mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 00:18:51 -06:00
Save and load LearningSolver; fix parallel_solve
This commit is contained in:
@@ -90,7 +90,7 @@ function save(filename::AbstractString, instance::JuMPInstance)::Nothing
|
||||
# Generate JLD2 file
|
||||
jldsave(
|
||||
filename;
|
||||
miplearn_version=0.2,
|
||||
miplearn_version="0.2",
|
||||
mps=mps,
|
||||
ext=ext_names,
|
||||
py_samples=py_samples,
|
||||
@@ -100,16 +100,23 @@ function save(filename::AbstractString, instance::JuMPInstance)::Nothing
|
||||
return
|
||||
end
|
||||
|
||||
function _check_miplearn_version(file)
|
||||
v = file["miplearn_version"]
|
||||
v == "0.2" || error(
|
||||
"The file you are trying to load has been generated by " *
|
||||
"MIPLearn $(v) and you are currently running MIPLearn 0.2. " *
|
||||
"Reading files generated by different versions of MIPLearn is " *
|
||||
"not currently supported."
|
||||
)
|
||||
end
|
||||
|
||||
|
||||
function load_jump_instance(filename::AbstractString)::JuMPInstance
|
||||
@info "Reading: $filename"
|
||||
instance = nothing
|
||||
time = @elapsed begin
|
||||
jldopen(filename, "r") do file
|
||||
file["miplearn_version"] == 0.2 || error(
|
||||
"MIPLearn version 0.2 cannot read instance files generated by " *
|
||||
"version $(file["miplearn_version"])."
|
||||
)
|
||||
_check_miplearn_version(file)
|
||||
|
||||
# Convert MPS to JuMP
|
||||
mps_filename = "$(tempname()).mps.gz"
|
||||
|
||||
@@ -37,16 +37,68 @@ end
|
||||
function parallel_solve!(solver::LearningSolver, instances::Vector{FileInstance})
|
||||
filenames = [instance.filename for instance in instances]
|
||||
optimizer_factory = solver.optimizer_factory
|
||||
solver_filename = tempname()
|
||||
save(solver_filename, solver)
|
||||
@sync @distributed for filename in filenames
|
||||
s = LearningSolver(optimizer_factory)
|
||||
s = load_solver(solver_filename)
|
||||
solve!(s, FileInstance(filename))
|
||||
nothing
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function save(filename::AbstractString, solver::LearningSolver)
|
||||
@info "Writing: $filename"
|
||||
time = @elapsed begin
|
||||
# Pickle solver.py
|
||||
internal_solver = solver.py.internal_solver
|
||||
internal_solver_prototype = solver.py.internal_solver_prototype
|
||||
solver.py.internal_solver = nothing
|
||||
solver.py.internal_solver_prototype = nothing
|
||||
solver_py_filename = tempname()
|
||||
miplearn.write_pickle_gz(solver.py, solver_py_filename, quiet=true)
|
||||
solver_py = read(solver_py_filename)
|
||||
solver.py.internal_solver = internal_solver
|
||||
solver.py.internal_solver_prototype = internal_solver_prototype
|
||||
|
||||
jldsave(
|
||||
filename;
|
||||
miplearn_version="0.2",
|
||||
solver_py=solver_py,
|
||||
optimizer_factory=solver.optimizer_factory,
|
||||
)
|
||||
end
|
||||
@info @sprintf("File written in %.2f seconds", time)
|
||||
return
|
||||
end
|
||||
|
||||
|
||||
function load_solver(filename::AbstractString)::LearningSolver
|
||||
@info "Reading: $filename"
|
||||
solver = nothing
|
||||
time = @elapsed begin
|
||||
jldopen(filename, "r") do file
|
||||
_check_miplearn_version(file)
|
||||
solve_py_filename = tempname()
|
||||
write(solve_py_filename, file["solver_py"])
|
||||
solver_py = miplearn.read_pickle_gz(solve_py_filename, quiet=true)
|
||||
internal_solver = JuMPSolver(file["optimizer_factory"])
|
||||
solver_py.internal_solver_prototype = internal_solver
|
||||
solver = LearningSolver(
|
||||
solver_py,
|
||||
file["optimizer_factory"],
|
||||
)
|
||||
end
|
||||
end
|
||||
@info @sprintf("File read in %.2f seconds", time)
|
||||
return solver
|
||||
end
|
||||
|
||||
|
||||
export Instance,
|
||||
LearningSolver,
|
||||
solve!,
|
||||
fit!,
|
||||
parallel_solve!
|
||||
parallel_solve!,
|
||||
save,
|
||||
load_solver
|
||||
|
||||
@@ -56,5 +56,16 @@ using Gurobi
|
||||
solver = LearningSolver(Gurobi.Optimizer)
|
||||
instance = JuMPInstance(model)
|
||||
stats = solve!(solver, instance)
|
||||
@test stats["mip_lower_bound"] == 2.0
|
||||
end
|
||||
|
||||
@testset "Save and load" begin
|
||||
solver = LearningSolver(Gurobi.Optimizer)
|
||||
solver.py.components = "Placeholder"
|
||||
filename = tempname()
|
||||
save(filename, solver)
|
||||
@test isfile(filename)
|
||||
loaded = load_solver(filename)
|
||||
@test loaded.py.components == "Placeholder"
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user