diff --git a/Project.toml b/Project.toml index 0e5281d..b4f017d 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" JuMP = "4076af6c-e467-56ae-b986-b466b2749572" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" diff --git a/README.md b/README.md index b0030df..4380454 100644 --- a/README.md +++ b/README.md @@ -147,14 +147,14 @@ for i in 1:600 @feature(...) @category(...) - # Save instances to a file + # Save instances to file instance = JuMPInstance(m) - save("instance-$i.bin", instance) + save("instance-$i.h5", instance) end # Initialize training and test instances -training_instances = [FileInstance("instance-$i.bin") for i in 1:500] -test_instances = [FileInstance("instance-$i.bin") for i in 501:600] +training_instances = [FileInstance("instance-$i.h5") for i in 1:500] +test_instances = [FileInstance("instance-$i.h5") for i in 501:600] # Initialize solver solver = LearningSolver(Cbc.Optimizer) diff --git a/src/instance/file.jl b/src/instance/file.jl index f608454..6f34b17 100644 --- a/src/instance/file.jl +++ b/src/instance/file.jl @@ -8,10 +8,13 @@ mutable struct FileInstance <: Instance py::Union{Nothing,PyCall.PyObject} loaded::Union{Nothing, JuMPInstance} filename::AbstractString + h5::PyCall.PyObject - function FileInstance(filename::String)::FileInstance + function FileInstance(filename::AbstractString)::FileInstance instance = new(nothing, nothing, filename) instance.py = PyFileInstance(instance) + instance.h5 = Hdf5Sample(filename) + instance.filename = filename return instance end end @@ -21,8 +24,14 @@ get_instance_features(instance::FileInstance) = get_instance_features(instance.l get_variable_features(instance::FileInstance) = get_variable_features(instance.loaded) get_variable_categories(instance::FileInstance) = get_variable_categories(instance.loaded) get_constraint_features(instance::FileInstance) = get_constraint_features(instance.loaded) -get_samples(instance::FileInstance) = get_samples(instance.loaded) -create_sample!(instance::FileInstance) = create_sample!(instance.loaded) + +function get_samples(instance::FileInstance) + return [instance.h5] +end + +function create_sample!(instance::FileInstance) + return instance.h5 +end function get_constraint_categories(instance::FileInstance) return get_constraint_categories(instance.loaded) @@ -41,7 +50,6 @@ function free(instance::FileInstance) end function flush(instance::FileInstance) - save(instance.filename, instance.loaded) end function __init_PyFileInstance__() diff --git a/src/instance/jump.jl b/src/instance/jump.jl index 0ecf055..5ae2c2b 100644 --- a/src/instance/jump.jl +++ b/src/instance/jump.jl @@ -3,7 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. using JuMP -using JLD2 +import JSON mutable struct JuMPInstance <: Instance py::Union{Nothing,PyCall.PyObject} @@ -75,42 +75,31 @@ function save(filename::AbstractString, instance::JuMPInstance)::Nothing write_to_file(model, mps_filename) mps = read(mps_filename) - # Pickle instance.py.samples. Ideally, we would use dumps and loads, but this - # causes some issues with PyCall, probably due to automatic type conversions. - samples_filename = tempname() - miplearn.write_pickle_gz(instance.samples, samples_filename) - samples = read(samples_filename) - - # Generate JLD2 file - jldsave( - filename; - miplearn_version="0.2", - mps=mps, - ext=model.ext[:miplearn], - samples=samples, - ) + # Generate HDF5 + h5 = Hdf5Sample(filename, mode="w") + h5.put_scalar("miplearn_version", "0002") + h5.put_bytes("mps", mps) + h5.put_scalar("jump_ext", JSON.json(model.ext[:miplearn])) return end -function _check_miplearn_version(file) - v = file["miplearn_version"] - v == "0.2" || error( +function _check_miplearn_version(h5) + v = h5.get_scalar("miplearn_version") + v == "0002" || error( "The file you are trying to load has been generated by " * - "MIPLearn $(v) and you are currently running MIPLearn 0.2. " * + "MIPLearn $(v) and you are currently running MIPLearn 0002 " * "Reading files generated by different versions of MIPLearn is " * "not currently supported." ) end function load_instance(filename::AbstractString)::JuMPInstance - jldopen(filename, "r") do file - _check_miplearn_version(file) - instance = JuMPInstance(file["mps"], file["ext"]) - samples_filename = tempname() - write(samples_filename, file["samples"]) - instance.samples = miplearn.read_pickle_gz(samples_filename) - return instance - end + h5 = Hdf5Sample(filename) + _check_miplearn_version(h5) + mps = h5.get_bytes("mps") + ext = h5.get_scalar("jump_ext") + instance = JuMPInstance(Vector{UInt8}(mps), JSON.parse(ext)) + return instance end export JuMPInstance, save, load_instance diff --git a/src/solvers/learning.jl b/src/solvers/learning.jl index 75762f9..9ee9b0b 100644 --- a/src/solvers/learning.jl +++ b/src/solvers/learning.jl @@ -3,6 +3,7 @@ # Released under the modified BSD license. See COPYING.md for more details. using Distributed +using JLD2 struct LearningSolver @@ -117,7 +118,6 @@ end function load_solver(filename::AbstractString)::LearningSolver jldopen(filename, "r") do file - _check_miplearn_version(file) solve_py_filename = tempname() write(solve_py_filename, file["solver_py"]) solver_py = miplearn.read_pickle_gz(solve_py_filename) diff --git a/test/instance/file_test.jl b/test/instance/file_test.jl index 4e63ba9..8552d98 100644 --- a/test/instance/file_test.jl +++ b/test/instance/file_test.jl @@ -14,11 +14,15 @@ using Cbc filename = tempname() save(filename, instance) + h5 = MIPLearn.Hdf5Sample(filename) + @test h5.get_scalar("miplearn_version") == "0002" + @test length(h5.get_bytes("mps")) > 0 + @test length(h5.get_scalar("jump_ext")) > 0 + file_instance = FileInstance(filename) solver = LearningSolver(Cbc.Optimizer) solve!(solver, file_instance) - loaded = load_instance(filename) - @test length(loaded.samples) == 1 + @test length(h5.get_vector("mip_var_values")) == 3 end end diff --git a/test/runtests.jl b/test/runtests.jl index 0bf0662..8bd0f37 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,8 +9,8 @@ MIPLearn.setup_logger() @testset "MIPLearn" begin include("fixtures/knapsack.jl") + include("instance/file_test.jl") include("solvers/jump_test.jl") include("solvers/learning_test.jl") - include("instance/file_test.jl") include("utils/benchmark_test.jl") end