mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 08:28:52 -06:00
Use HDF5 for instance files
This commit is contained in:
@@ -8,10 +8,13 @@ mutable struct FileInstance <: Instance
|
||||
py::Union{Nothing,PyCall.PyObject}
|
||||
loaded::Union{Nothing, JuMPInstance}
|
||||
filename::AbstractString
|
||||
h5::PyCall.PyObject
|
||||
|
||||
function FileInstance(filename::String)::FileInstance
|
||||
function FileInstance(filename::AbstractString)::FileInstance
|
||||
instance = new(nothing, nothing, filename)
|
||||
instance.py = PyFileInstance(instance)
|
||||
instance.h5 = Hdf5Sample(filename)
|
||||
instance.filename = filename
|
||||
return instance
|
||||
end
|
||||
end
|
||||
@@ -21,8 +24,14 @@ get_instance_features(instance::FileInstance) = get_instance_features(instance.l
|
||||
get_variable_features(instance::FileInstance) = get_variable_features(instance.loaded)
|
||||
get_variable_categories(instance::FileInstance) = get_variable_categories(instance.loaded)
|
||||
get_constraint_features(instance::FileInstance) = get_constraint_features(instance.loaded)
|
||||
get_samples(instance::FileInstance) = get_samples(instance.loaded)
|
||||
create_sample!(instance::FileInstance) = create_sample!(instance.loaded)
|
||||
|
||||
function get_samples(instance::FileInstance)
|
||||
return [instance.h5]
|
||||
end
|
||||
|
||||
function create_sample!(instance::FileInstance)
|
||||
return instance.h5
|
||||
end
|
||||
|
||||
function get_constraint_categories(instance::FileInstance)
|
||||
return get_constraint_categories(instance.loaded)
|
||||
@@ -41,7 +50,6 @@ function free(instance::FileInstance)
|
||||
end
|
||||
|
||||
function flush(instance::FileInstance)
|
||||
save(instance.filename, instance.loaded)
|
||||
end
|
||||
|
||||
function __init_PyFileInstance__()
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
using JuMP
|
||||
using JLD2
|
||||
import JSON
|
||||
|
||||
mutable struct JuMPInstance <: Instance
|
||||
py::Union{Nothing,PyCall.PyObject}
|
||||
@@ -75,42 +75,31 @@ function save(filename::AbstractString, instance::JuMPInstance)::Nothing
|
||||
write_to_file(model, mps_filename)
|
||||
mps = read(mps_filename)
|
||||
|
||||
# Pickle instance.py.samples. Ideally, we would use dumps and loads, but this
|
||||
# causes some issues with PyCall, probably due to automatic type conversions.
|
||||
samples_filename = tempname()
|
||||
miplearn.write_pickle_gz(instance.samples, samples_filename)
|
||||
samples = read(samples_filename)
|
||||
|
||||
# Generate JLD2 file
|
||||
jldsave(
|
||||
filename;
|
||||
miplearn_version="0.2",
|
||||
mps=mps,
|
||||
ext=model.ext[:miplearn],
|
||||
samples=samples,
|
||||
)
|
||||
# Generate HDF5
|
||||
h5 = Hdf5Sample(filename, mode="w")
|
||||
h5.put_scalar("miplearn_version", "0002")
|
||||
h5.put_bytes("mps", mps)
|
||||
h5.put_scalar("jump_ext", JSON.json(model.ext[:miplearn]))
|
||||
return
|
||||
end
|
||||
|
||||
function _check_miplearn_version(file)
|
||||
v = file["miplearn_version"]
|
||||
v == "0.2" || error(
|
||||
function _check_miplearn_version(h5)
|
||||
v = h5.get_scalar("miplearn_version")
|
||||
v == "0002" || error(
|
||||
"The file you are trying to load has been generated by " *
|
||||
"MIPLearn $(v) and you are currently running MIPLearn 0.2. " *
|
||||
"MIPLearn $(v) and you are currently running MIPLearn 0002 " *
|
||||
"Reading files generated by different versions of MIPLearn is " *
|
||||
"not currently supported."
|
||||
)
|
||||
end
|
||||
|
||||
function load_instance(filename::AbstractString)::JuMPInstance
|
||||
jldopen(filename, "r") do file
|
||||
_check_miplearn_version(file)
|
||||
instance = JuMPInstance(file["mps"], file["ext"])
|
||||
samples_filename = tempname()
|
||||
write(samples_filename, file["samples"])
|
||||
instance.samples = miplearn.read_pickle_gz(samples_filename)
|
||||
return instance
|
||||
end
|
||||
h5 = Hdf5Sample(filename)
|
||||
_check_miplearn_version(h5)
|
||||
mps = h5.get_bytes("mps")
|
||||
ext = h5.get_scalar("jump_ext")
|
||||
instance = JuMPInstance(Vector{UInt8}(mps), JSON.parse(ext))
|
||||
return instance
|
||||
end
|
||||
|
||||
export JuMPInstance, save, load_instance
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
using Distributed
|
||||
using JLD2
|
||||
|
||||
|
||||
struct LearningSolver
|
||||
@@ -117,7 +118,6 @@ end
|
||||
|
||||
function load_solver(filename::AbstractString)::LearningSolver
|
||||
jldopen(filename, "r") do file
|
||||
_check_miplearn_version(file)
|
||||
solve_py_filename = tempname()
|
||||
write(solve_py_filename, file["solver_py"])
|
||||
solver_py = miplearn.read_pickle_gz(solve_py_filename)
|
||||
|
||||
Reference in New Issue
Block a user