mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 00:18:51 -06:00
Add write_jld2, reformat code
This commit is contained in:
@@ -8,6 +8,7 @@ Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
|
||||
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
|
||||
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
|
||||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
||||
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
|
||||
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
|
||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
|
||||
@@ -135,9 +135,8 @@ function _create_node(
|
||||
status, obj = solve_relaxation!(mip)
|
||||
if status == :Optimal
|
||||
vals = values(mip, mip.int_vars)
|
||||
fractional_indices = [
|
||||
j for j in 1:length(mip.int_vars) if 1e-6 < vals[j] - floor(vals[j]) < 1 - 1e-6
|
||||
]
|
||||
fractional_indices =
|
||||
[j for j = 1:length(mip.int_vars) if 1e-6 < vals[j] - floor(vals[j]) < 1 - 1e-6]
|
||||
fractional_values = vals[fractional_indices]
|
||||
fractional_variables = mip.int_vars[fractional_indices]
|
||||
else
|
||||
|
||||
@@ -6,12 +6,7 @@ import ..H5File
|
||||
|
||||
using OrderedCollections
|
||||
|
||||
function collect_gmi(
|
||||
mps_filename;
|
||||
optimizer,
|
||||
max_rounds=10,
|
||||
max_cuts_per_round=100,
|
||||
)
|
||||
function collect_gmi(mps_filename; optimizer, max_rounds = 10, max_cuts_per_round = 100)
|
||||
@info mps_filename
|
||||
reset_timer!()
|
||||
|
||||
@@ -24,7 +19,7 @@ function collect_gmi(
|
||||
zip(
|
||||
h5.get_array("static_var_names"),
|
||||
convert(Array{Float64}, h5.get_array("mip_var_values")),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Read optimal value
|
||||
@@ -52,7 +47,7 @@ function collect_gmi(
|
||||
# Read problem
|
||||
model = read_from_file(mps_filename)
|
||||
|
||||
for round in 1:max_rounds
|
||||
for round = 1:max_rounds
|
||||
@info "Round $(round)..."
|
||||
|
||||
stats_time_convert = @elapsed begin
|
||||
@@ -97,22 +92,13 @@ function collect_gmi(
|
||||
basis = get_basis(model_s)
|
||||
sol_frac = get_x(model_s)
|
||||
stats_time_select += @elapsed begin
|
||||
selected_rows = select_gmi_rows(
|
||||
data_s,
|
||||
basis,
|
||||
sol_frac,
|
||||
max_rows=max_cuts_per_round,
|
||||
)
|
||||
selected_rows =
|
||||
select_gmi_rows(data_s, basis, sol_frac, max_rows = max_cuts_per_round)
|
||||
end
|
||||
|
||||
# Compute selected tableau rows
|
||||
stats_time_tableau += @elapsed begin
|
||||
tableau = compute_tableau(
|
||||
data_s,
|
||||
basis,
|
||||
sol_frac,
|
||||
rows=selected_rows,
|
||||
)
|
||||
tableau = compute_tableau(data_s, basis, sol_frac, rows = selected_rows)
|
||||
|
||||
# Assert tableau rows have been computed correctly
|
||||
@assert tableau.lhs * sol_frac ≈ tableau.rhs
|
||||
@@ -153,10 +139,7 @@ function collect_gmi(
|
||||
push!(stats_gap, gap(obj))
|
||||
|
||||
# Store useful cuts; drop useless ones from the problem
|
||||
useful = [
|
||||
abs(shadow_price(c)) > 1e-3
|
||||
for c in constrs
|
||||
]
|
||||
useful = [abs(shadow_price(c)) > 1e-3 for c in constrs]
|
||||
drop = findall(useful .== false)
|
||||
keep = findall(useful .== true)
|
||||
delete.(model, constrs[drop])
|
||||
|
||||
@@ -7,32 +7,28 @@ using TimerOutputs
|
||||
|
||||
@inline frac(x::Float64) = x - floor(x)
|
||||
|
||||
function select_gmi_rows(data, basis, x; max_rows=10, atol=0.001)
|
||||
function select_gmi_rows(data, basis, x; max_rows = 10, atol = 0.001)
|
||||
candidate_rows = [
|
||||
r
|
||||
for r in 1:length(basis.var_basic)
|
||||
if (data.var_types[basis.var_basic[r]] != 'C') && (frac(x[basis.var_basic[r]]) > atol)
|
||||
r for
|
||||
r = 1:length(basis.var_basic) if (data.var_types[basis.var_basic[r]] != 'C') &&
|
||||
(frac(x[basis.var_basic[r]]) > atol)
|
||||
]
|
||||
candidate_vals = frac.(x[basis.var_basic[candidate_rows]])
|
||||
score = abs.(candidate_vals .- 0.5)
|
||||
perm = sortperm(score)
|
||||
return [candidate_rows[perm[i]] for i in 1:min(length(perm), max_rows)]
|
||||
return [candidate_rows[perm[i]] for i = 1:min(length(perm), max_rows)]
|
||||
end
|
||||
|
||||
function compute_gmi(
|
||||
data::ProblemData,
|
||||
tableau::Tableau,
|
||||
tol=1e-8,
|
||||
)::ConstraintSet
|
||||
function compute_gmi(data::ProblemData, tableau::Tableau, tol = 1e-8)::ConstraintSet
|
||||
nrows, ncols = size(tableau.lhs)
|
||||
ub = Float64[Inf for _ in 1:nrows]
|
||||
lb = Float64[0.999 for _ in 1:nrows]
|
||||
ub = Float64[Inf for _ = 1:nrows]
|
||||
lb = Float64[0.999 for _ = 1:nrows]
|
||||
tableau_I, tableau_J, tableau_V = findnz(tableau.lhs)
|
||||
lhs_I = Int[]
|
||||
lhs_J = Int[]
|
||||
lhs_V = Float64[]
|
||||
@timeit "Compute coefficients" begin
|
||||
for k in 1:nnz(tableau.lhs)
|
||||
for k = 1:nnz(tableau.lhs)
|
||||
i::Int = tableau_I[k]
|
||||
v::Float64 = 0.0
|
||||
alpha_j = frac(tableau_V[k])
|
||||
@@ -61,12 +57,8 @@ function compute_gmi(
|
||||
return ConstraintSet(; lhs, ub, lb)
|
||||
end
|
||||
|
||||
function assert_cuts_off(
|
||||
cuts::ConstraintSet,
|
||||
x::Vector{Float64},
|
||||
tol=1e-6
|
||||
)
|
||||
for i in 1:length(cuts.lb)
|
||||
function assert_cuts_off(cuts::ConstraintSet, x::Vector{Float64}, tol = 1e-6)
|
||||
for i = 1:length(cuts.lb)
|
||||
val = cuts.lhs[i, :]' * x
|
||||
if (val <= cuts.ub[i] - tol) && (val >= cuts.lb[i] + tol)
|
||||
throw(ErrorException("inequality fails to cut off fractional solution"))
|
||||
@@ -74,17 +66,17 @@ function assert_cuts_off(
|
||||
end
|
||||
end
|
||||
|
||||
function assert_does_not_cut_off(
|
||||
cuts::ConstraintSet,
|
||||
x::Vector{Float64};
|
||||
tol=1e-6
|
||||
)
|
||||
for i in 1:length(cuts.lb)
|
||||
function assert_does_not_cut_off(cuts::ConstraintSet, x::Vector{Float64}; tol = 1e-6)
|
||||
for i = 1:length(cuts.lb)
|
||||
val = cuts.lhs[i, :]' * x
|
||||
ub = cuts.ub[i]
|
||||
lb = cuts.lb[i]
|
||||
if (val >= ub) || (val <= lb)
|
||||
throw(ErrorException("inequality $i cuts off integer solution ($lb <= $val <= $ub)"))
|
||||
throw(
|
||||
ErrorException(
|
||||
"inequality $i cuts off integer solution ($lb <= $val <= $ub)",
|
||||
),
|
||||
)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@@ -27,11 +27,7 @@ function ProblemData(model::Model)::ProblemData
|
||||
for (ftype, stype) in list_of_constraint_types(model)
|
||||
for constr in all_constraints(model, ftype, stype)
|
||||
cset = MOI.get(constr.model.moi_backend, MOI.ConstraintSet(), constr.index)
|
||||
cf = MOI.get(
|
||||
constr.model.moi_backend,
|
||||
MOI.ConstraintFunction(),
|
||||
constr.index,
|
||||
)
|
||||
cf = MOI.get(constr.model.moi_backend, MOI.ConstraintFunction(), constr.index)
|
||||
if ftype == VariableRef
|
||||
var_idx = cf.value
|
||||
if stype == MOI.Integer || stype == MOI.ZeroOne
|
||||
@@ -79,13 +75,7 @@ function ProblemData(model::Model)::ProblemData
|
||||
|
||||
n = length(vars)
|
||||
m = constr_index - 1
|
||||
constr_lhs = sparse(
|
||||
constr_lhs_rows,
|
||||
constr_lhs_cols,
|
||||
constr_lhs_values,
|
||||
m,
|
||||
n,
|
||||
)
|
||||
constr_lhs = sparse(constr_lhs_rows, constr_lhs_cols, constr_lhs_values, m, n)
|
||||
|
||||
@assert length(obj) == n
|
||||
@assert length(var_lb) == n
|
||||
@@ -96,7 +86,7 @@ function ProblemData(model::Model)::ProblemData
|
||||
@assert length(constr_ub) == m
|
||||
|
||||
return ProblemData(
|
||||
obj_offset=0.0;
|
||||
obj_offset = 0.0;
|
||||
obj,
|
||||
constr_lb,
|
||||
constr_ub,
|
||||
@@ -104,11 +94,11 @@ function ProblemData(model::Model)::ProblemData
|
||||
var_lb,
|
||||
var_ub,
|
||||
var_types,
|
||||
var_names
|
||||
var_names,
|
||||
)
|
||||
end
|
||||
|
||||
function to_model(data::ProblemData, tol=1e-6)::Model
|
||||
function to_model(data::ProblemData, tol = 1e-6)::Model
|
||||
model = Model()
|
||||
|
||||
# Variables
|
||||
@@ -153,7 +143,7 @@ function add_constraint_set(model::JuMP.Model, cs::ConstraintSet)
|
||||
vars = all_variables(model)
|
||||
nrows, _ = size(cs.lhs)
|
||||
constrs = []
|
||||
for i in 1:nrows
|
||||
for i = 1:nrows
|
||||
c = nothing
|
||||
if isinf(cs.ub[i])
|
||||
c = @constraint(model, cs.lb[i] <= dot(cs.lhs[i, :], vars))
|
||||
|
||||
@@ -17,17 +17,17 @@ Base.@kwdef mutable struct ProblemData
|
||||
end
|
||||
|
||||
Base.@kwdef mutable struct Tableau
|
||||
obj
|
||||
lhs
|
||||
rhs
|
||||
z
|
||||
obj::Any
|
||||
lhs::Any
|
||||
rhs::Any
|
||||
z::Any
|
||||
end
|
||||
|
||||
Base.@kwdef mutable struct Basis
|
||||
var_basic
|
||||
var_nonbasic
|
||||
constr_basic
|
||||
constr_nonbasic
|
||||
var_basic::Any
|
||||
var_nonbasic::Any
|
||||
constr_basic::Any
|
||||
constr_nonbasic::Any
|
||||
end
|
||||
|
||||
Base.@kwdef mutable struct ConstraintSet
|
||||
|
||||
@@ -56,8 +56,8 @@ function compute_tableau(
|
||||
data::ProblemData,
|
||||
basis::Basis,
|
||||
x::Vector{Float64};
|
||||
rows::Union{Vector{Int},Nothing}=nothing,
|
||||
tol=1e-8
|
||||
rows::Union{Vector{Int},Nothing} = nothing,
|
||||
tol = 1e-8,
|
||||
)::Tableau
|
||||
@timeit "Split data" begin
|
||||
nrows, ncols = size(data.constr_lhs)
|
||||
@@ -77,7 +77,7 @@ function compute_tableau(
|
||||
tableau_lhs_I = Int[]
|
||||
tableau_lhs_J = Int[]
|
||||
tableau_lhs_V = Float64[]
|
||||
for k in 1:length(rows)
|
||||
for k = 1:length(rows)
|
||||
@timeit "Prepare inputs" begin
|
||||
i = rows[k]
|
||||
e = zeros(nrows)
|
||||
@@ -100,13 +100,8 @@ function compute_tableau(
|
||||
end
|
||||
end
|
||||
end
|
||||
tableau_lhs = sparse(
|
||||
tableau_lhs_I,
|
||||
tableau_lhs_J,
|
||||
tableau_lhs_V,
|
||||
length(rows),
|
||||
ncols,
|
||||
)
|
||||
tableau_lhs =
|
||||
sparse(tableau_lhs_I, tableau_lhs_J, tableau_lhs_V, length(rows), ncols)
|
||||
end
|
||||
|
||||
@timeit "Compute tableau RHS" begin
|
||||
@@ -120,10 +115,10 @@ function compute_tableau(
|
||||
end
|
||||
|
||||
return Tableau(
|
||||
obj=tableau_obj,
|
||||
lhs=tableau_lhs,
|
||||
rhs=tableau_rhs,
|
||||
z=dot(data.obj, x),
|
||||
obj = tableau_obj,
|
||||
lhs = tableau_lhs,
|
||||
rhs = tableau_rhs,
|
||||
z = dot(data.obj, x),
|
||||
)
|
||||
end
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ function _isbounded(x)
|
||||
return true
|
||||
end
|
||||
|
||||
function backwards!(transforms::Vector{Transform}, m::ConstraintSet; tol=1e-8)
|
||||
function backwards!(transforms::Vector{Transform}, m::ConstraintSet; tol = 1e-8)
|
||||
for t in reverse(transforms)
|
||||
backwards!(t, m)
|
||||
end
|
||||
@@ -24,7 +24,7 @@ function backwards!(transforms::Vector{Transform}, m::ConstraintSet; tol=1e-8)
|
||||
end
|
||||
end
|
||||
|
||||
function backwards(transforms::Vector{Transform}, m::ConstraintSet; tol=1e-8)
|
||||
function backwards(transforms::Vector{Transform}, m::ConstraintSet; tol = 1e-8)
|
||||
m2 = deepcopy(m)
|
||||
backwards!(transforms, m2; tol)
|
||||
return m2
|
||||
@@ -68,7 +68,7 @@ Base.@kwdef mutable struct MoveVarUpperBoundsToConstrs <: Transform end
|
||||
function forward!(t::MoveVarUpperBoundsToConstrs, data::ProblemData)
|
||||
_, ncols = size(data.constr_lhs)
|
||||
data.constr_lhs = [data.constr_lhs; I]
|
||||
data.constr_lb = [data.constr_lb; [-Inf for _ in 1:ncols]]
|
||||
data.constr_lb = [data.constr_lb; [-Inf for _ = 1:ncols]]
|
||||
data.constr_ub = [data.constr_ub; data.var_ub]
|
||||
data.var_ub .= Inf
|
||||
end
|
||||
@@ -98,9 +98,9 @@ end
|
||||
function forward!(t::AddSlackVariables, data::ProblemData)
|
||||
nrows, ncols = size(data.constr_lhs)
|
||||
isequality = abs.(data.constr_ub .- data.constr_lb) .< 1e-6
|
||||
eq = [i for i in 1:nrows if isequality[i]]
|
||||
ge = [i for i in 1:nrows if isfinite(data.constr_lb[i]) && !isequality[i]]
|
||||
le = [i for i in 1:nrows if isfinite(data.constr_ub[i]) && !isequality[i]]
|
||||
eq = [i for i = 1:nrows if isequality[i]]
|
||||
ge = [i for i = 1:nrows if isfinite(data.constr_lb[i]) && !isequality[i]]
|
||||
le = [i for i = 1:nrows if isfinite(data.constr_ub[i]) && !isequality[i]]
|
||||
EQ, GE, LE = length(eq), length(ge), length(le)
|
||||
|
||||
t.M1 = [
|
||||
@@ -128,8 +128,8 @@ function forward!(t::AddSlackVariables, data::ProblemData)
|
||||
data.obj = [data.obj; zeros(GE + LE)]
|
||||
data.var_lb = [data.var_lb; zeros(GE + LE)]
|
||||
data.var_ub = [data.var_ub; [Inf for _ = 1:(GE+LE)]]
|
||||
data.var_names = [data.var_names; ["__s$i" for i in 1:(GE+LE)]]
|
||||
data.var_types = [data.var_types; ['C' for _ in 1:(GE+LE)]]
|
||||
data.var_names = [data.var_names; ["__s$i" for i = 1:(GE+LE)]]
|
||||
data.var_types = [data.var_types; ['C' for _ = 1:(GE+LE)]]
|
||||
data.constr_lb = [
|
||||
data.constr_lb[eq]
|
||||
data.constr_lb[ge]
|
||||
@@ -157,15 +157,15 @@ end
|
||||
Base.@kwdef mutable struct SplitFreeVars <: Transform
|
||||
F::Int = 0
|
||||
B::Int = 0
|
||||
free::Vector{Int}=[]
|
||||
others::Vector{Int}=[]
|
||||
free::Vector{Int} = []
|
||||
others::Vector{Int} = []
|
||||
end
|
||||
|
||||
function forward!(t::SplitFreeVars, data::ProblemData)
|
||||
lhs = data.constr_lhs
|
||||
_, ncols = size(lhs)
|
||||
free = [i for i in 1:ncols if !isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i])]
|
||||
others = [i for i in 1:ncols if isfinite(data.var_lb[i]) || isfinite(data.var_ub[i])]
|
||||
free = [i for i = 1:ncols if !isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i])]
|
||||
others = [i for i = 1:ncols if isfinite(data.var_lb[i]) || isfinite(data.var_ub[i])]
|
||||
t.F = length(free)
|
||||
t.B = length(others)
|
||||
t.free, t.others = free, others
|
||||
@@ -200,7 +200,7 @@ end
|
||||
function backwards!(t::SplitFreeVars, c::ConstraintSet)
|
||||
# Convert GE constraints into LE
|
||||
nrows, _ = size(c.lhs)
|
||||
ge = [i for i in 1:nrows if isfinite(c.lb[i])]
|
||||
ge = [i for i = 1:nrows if isfinite(c.lb[i])]
|
||||
c.ub[ge], c.lb[ge] = -c.lb[ge], -c.ub[ge]
|
||||
c.lhs[ge, :] *= -1
|
||||
|
||||
@@ -209,8 +209,8 @@ function backwards!(t::SplitFreeVars, c::ConstraintSet)
|
||||
|
||||
# Take minimum (weakest) coefficient
|
||||
B, F = t.B, t.F
|
||||
for i in 1:F
|
||||
c.lhs[:, B + i] = min.(c.lhs[:, B + i], -c.lhs[:, B + F + i])
|
||||
for i = 1:F
|
||||
c.lhs[:, B+i] = min.(c.lhs[:, B+i], -c.lhs[:, B+F+i])
|
||||
end
|
||||
c.lhs = c.lhs[:, 1:(B+F)]
|
||||
end
|
||||
@@ -231,7 +231,7 @@ end
|
||||
|
||||
function forward!(t::FlipUnboundedBelowVars, data::ProblemData)
|
||||
_, ncols = size(data.constr_lhs)
|
||||
for i in 1:ncols
|
||||
for i = 1:ncols
|
||||
if isfinite(data.var_lb[i])
|
||||
continue
|
||||
end
|
||||
|
||||
26
src/io.jl
26
src/io.jl
@@ -2,6 +2,9 @@
|
||||
# Copyright (C) 2020-2023, UChicago Argonne, LLC. All rights reserved.
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
using Printf
|
||||
using JLD2
|
||||
|
||||
global H5File = PyNULL()
|
||||
global write_pkl_gz = PyNULL()
|
||||
global read_pkl_gz = PyNULL()
|
||||
@@ -36,8 +39,27 @@ end
|
||||
function PyObject(m::SparseMatrixCSC)
|
||||
pyimport("scipy.sparse").csc_matrix(
|
||||
(m.nzval, m.rowval .- 1, m.colptr .- 1),
|
||||
shape = size(m),
|
||||
shape=size(m),
|
||||
).tocoo()
|
||||
end
|
||||
|
||||
export H5File, write_pkl_gz, read_pkl_gz
|
||||
function write_jld2(
|
||||
objs::Vector,
|
||||
dirname::AbstractString;
|
||||
prefix::AbstractString=""
|
||||
)::Vector{String}
|
||||
mkpath(dirname)
|
||||
filenames = [@sprintf("%s/%s%05d.jld2", dirname, prefix, i) for i = 1:length(objs)]
|
||||
for (i, obj) in enumerate(objs)
|
||||
jldsave(filenames[i]; obj)
|
||||
end
|
||||
return filenames
|
||||
end
|
||||
|
||||
function read_jld2(filename::AbstractString)::Any
|
||||
jldopen(filename, "r") do file
|
||||
return file["obj"]
|
||||
end
|
||||
end
|
||||
|
||||
export H5File, write_pkl_gz, read_pkl_gz, write_jld2, read_jld2
|
||||
|
||||
@@ -278,8 +278,10 @@ function _set_warm_starts(model::JuMP.Model, var_names, var_values, stats)
|
||||
n_starts == 1 || error("JuMP does not support multiple warm starts")
|
||||
vars = [variable_by_name(model, v) for v in var_names]
|
||||
for (i, var) in enumerate(vars)
|
||||
if isfinite(var_values[i])
|
||||
set_start_value(var, var_values[i])
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function _write(model::JuMP.Model, filename)
|
||||
|
||||
@@ -8,6 +8,7 @@ Clp = "e2554f3b-3117-50c0-817c-e040a3ddf72d"
|
||||
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
|
||||
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
|
||||
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
|
||||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
||||
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
|
||||
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
|
||||
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
|
||||
@@ -15,3 +16,7 @@ MIPLearn = "2b1277c3-b477-4c49-a15e-7ba350325c68"
|
||||
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
|
||||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[compat]
|
||||
JuMP = "1"
|
||||
julia = "1"
|
||||
|
||||
BIN
test/fixtures/bell5.h5
vendored
BIN
test/fixtures/bell5.h5
vendored
Binary file not shown.
@@ -3,9 +3,16 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
using MIPLearn
|
||||
using JLD2
|
||||
|
||||
struct _TestStruct
|
||||
n::Int
|
||||
q::Vector{Float64}
|
||||
end
|
||||
|
||||
function test_io()
|
||||
test_pkl_gz()
|
||||
test_jld2()
|
||||
test_h5()
|
||||
end
|
||||
|
||||
@@ -32,6 +39,26 @@ function test_h5()
|
||||
h5.close()
|
||||
end
|
||||
|
||||
function test_jld2()
|
||||
dirname = mktempdir()
|
||||
data = [
|
||||
_TestStruct(1, [0.0, 0.0, 0.0]),
|
||||
_TestStruct(2, [1.0, 2.0, 3.0]),
|
||||
_TestStruct(3, [3.0, 3.0, 3.0]),
|
||||
]
|
||||
filenames = write_jld2(data, dirname, prefix="obj")
|
||||
@test all(
|
||||
filenames .==
|
||||
["$dirname/obj00001.jld2", "$dirname/obj00002.jld2", "$dirname/obj00003.jld2"],
|
||||
)
|
||||
@assert isfile("$dirname/obj00001.jld2")
|
||||
@assert isfile("$dirname/obj00002.jld2")
|
||||
@assert isfile("$dirname/obj00003.jld2")
|
||||
recovered = read_jld2("$dirname/obj00002.jld2")
|
||||
@test recovered.n == 2
|
||||
@test all(recovered.q .== [1.0, 2.0, 3.0])
|
||||
end
|
||||
|
||||
function _test_roundtrip_scalar(h5, original)
|
||||
h5.put_scalar("key", original)
|
||||
recovered = h5.get_scalar("key")
|
||||
|
||||
Reference in New Issue
Block a user