Add write_jld2, reformat code

master
Alinson S. Xavier 2 years ago
parent b82a984ab1
commit d6025c5f4a
Signed by: isoron
GPG Key ID: 0DA8E4B9E1109DCA

@ -8,6 +8,7 @@ Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b" HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572" JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434" KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

@ -135,9 +135,8 @@ function _create_node(
status, obj = solve_relaxation!(mip) status, obj = solve_relaxation!(mip)
if status == :Optimal if status == :Optimal
vals = values(mip, mip.int_vars) vals = values(mip, mip.int_vars)
fractional_indices = [ fractional_indices =
j for j in 1:length(mip.int_vars) if 1e-6 < vals[j] - floor(vals[j]) < 1 - 1e-6 [j for j = 1:length(mip.int_vars) if 1e-6 < vals[j] - floor(vals[j]) < 1 - 1e-6]
]
fractional_values = vals[fractional_indices] fractional_values = vals[fractional_indices]
fractional_variables = mip.int_vars[fractional_indices] fractional_variables = mip.int_vars[fractional_indices]
else else

@ -6,12 +6,7 @@ import ..H5File
using OrderedCollections using OrderedCollections
function collect_gmi( function collect_gmi(mps_filename; optimizer, max_rounds = 10, max_cuts_per_round = 100)
mps_filename;
optimizer,
max_rounds=10,
max_cuts_per_round=100,
)
@info mps_filename @info mps_filename
reset_timer!() reset_timer!()
@ -24,7 +19,7 @@ function collect_gmi(
zip( zip(
h5.get_array("static_var_names"), h5.get_array("static_var_names"),
convert(Array{Float64}, h5.get_array("mip_var_values")), convert(Array{Float64}, h5.get_array("mip_var_values")),
) ),
) )
# Read optimal value # Read optimal value
@ -52,7 +47,7 @@ function collect_gmi(
# Read problem # Read problem
model = read_from_file(mps_filename) model = read_from_file(mps_filename)
for round in 1:max_rounds for round = 1:max_rounds
@info "Round $(round)..." @info "Round $(round)..."
stats_time_convert = @elapsed begin stats_time_convert = @elapsed begin
@ -97,22 +92,13 @@ function collect_gmi(
basis = get_basis(model_s) basis = get_basis(model_s)
sol_frac = get_x(model_s) sol_frac = get_x(model_s)
stats_time_select += @elapsed begin stats_time_select += @elapsed begin
selected_rows = select_gmi_rows( selected_rows =
data_s, select_gmi_rows(data_s, basis, sol_frac, max_rows = max_cuts_per_round)
basis,
sol_frac,
max_rows=max_cuts_per_round,
)
end end
# Compute selected tableau rows # Compute selected tableau rows
stats_time_tableau += @elapsed begin stats_time_tableau += @elapsed begin
tableau = compute_tableau( tableau = compute_tableau(data_s, basis, sol_frac, rows = selected_rows)
data_s,
basis,
sol_frac,
rows=selected_rows,
)
# Assert tableau rows have been computed correctly # Assert tableau rows have been computed correctly
@assert tableau.lhs * sol_frac tableau.rhs @assert tableau.lhs * sol_frac tableau.rhs
@ -153,10 +139,7 @@ function collect_gmi(
push!(stats_gap, gap(obj)) push!(stats_gap, gap(obj))
# Store useful cuts; drop useless ones from the problem # Store useful cuts; drop useless ones from the problem
useful = [ useful = [abs(shadow_price(c)) > 1e-3 for c in constrs]
abs(shadow_price(c)) > 1e-3
for c in constrs
]
drop = findall(useful .== false) drop = findall(useful .== false)
keep = findall(useful .== true) keep = findall(useful .== true)
delete.(model, constrs[drop]) delete.(model, constrs[drop])

@ -7,32 +7,28 @@ using TimerOutputs
@inline frac(x::Float64) = x - floor(x) @inline frac(x::Float64) = x - floor(x)
function select_gmi_rows(data, basis, x; max_rows=10, atol=0.001) function select_gmi_rows(data, basis, x; max_rows = 10, atol = 0.001)
candidate_rows = [ candidate_rows = [
r r for
for r in 1:length(basis.var_basic) r = 1:length(basis.var_basic) if (data.var_types[basis.var_basic[r]] != 'C') &&
if (data.var_types[basis.var_basic[r]] != 'C') && (frac(x[basis.var_basic[r]]) > atol) (frac(x[basis.var_basic[r]]) > atol)
] ]
candidate_vals = frac.(x[basis.var_basic[candidate_rows]]) candidate_vals = frac.(x[basis.var_basic[candidate_rows]])
score = abs.(candidate_vals .- 0.5) score = abs.(candidate_vals .- 0.5)
perm = sortperm(score) perm = sortperm(score)
return [candidate_rows[perm[i]] for i in 1:min(length(perm), max_rows)] return [candidate_rows[perm[i]] for i = 1:min(length(perm), max_rows)]
end end
function compute_gmi( function compute_gmi(data::ProblemData, tableau::Tableau, tol = 1e-8)::ConstraintSet
data::ProblemData,
tableau::Tableau,
tol=1e-8,
)::ConstraintSet
nrows, ncols = size(tableau.lhs) nrows, ncols = size(tableau.lhs)
ub = Float64[Inf for _ in 1:nrows] ub = Float64[Inf for _ = 1:nrows]
lb = Float64[0.999 for _ in 1:nrows] lb = Float64[0.999 for _ = 1:nrows]
tableau_I, tableau_J, tableau_V = findnz(tableau.lhs) tableau_I, tableau_J, tableau_V = findnz(tableau.lhs)
lhs_I = Int[] lhs_I = Int[]
lhs_J = Int[] lhs_J = Int[]
lhs_V = Float64[] lhs_V = Float64[]
@timeit "Compute coefficients" begin @timeit "Compute coefficients" begin
for k in 1:nnz(tableau.lhs) for k = 1:nnz(tableau.lhs)
i::Int = tableau_I[k] i::Int = tableau_I[k]
v::Float64 = 0.0 v::Float64 = 0.0
alpha_j = frac(tableau_V[k]) alpha_j = frac(tableau_V[k])
@ -61,12 +57,8 @@ function compute_gmi(
return ConstraintSet(; lhs, ub, lb) return ConstraintSet(; lhs, ub, lb)
end end
function assert_cuts_off( function assert_cuts_off(cuts::ConstraintSet, x::Vector{Float64}, tol = 1e-6)
cuts::ConstraintSet, for i = 1:length(cuts.lb)
x::Vector{Float64},
tol=1e-6
)
for i in 1:length(cuts.lb)
val = cuts.lhs[i, :]' * x val = cuts.lhs[i, :]' * x
if (val <= cuts.ub[i] - tol) && (val >= cuts.lb[i] + tol) if (val <= cuts.ub[i] - tol) && (val >= cuts.lb[i] + tol)
throw(ErrorException("inequality fails to cut off fractional solution")) throw(ErrorException("inequality fails to cut off fractional solution"))
@ -74,17 +66,17 @@ function assert_cuts_off(
end end
end end
function assert_does_not_cut_off( function assert_does_not_cut_off(cuts::ConstraintSet, x::Vector{Float64}; tol = 1e-6)
cuts::ConstraintSet, for i = 1:length(cuts.lb)
x::Vector{Float64};
tol=1e-6
)
for i in 1:length(cuts.lb)
val = cuts.lhs[i, :]' * x val = cuts.lhs[i, :]' * x
ub = cuts.ub[i] ub = cuts.ub[i]
lb = cuts.lb[i] lb = cuts.lb[i]
if (val >= ub) || (val <= lb) if (val >= ub) || (val <= lb)
throw(ErrorException("inequality $i cuts off integer solution ($lb <= $val <= $ub)")) throw(
ErrorException(
"inequality $i cuts off integer solution ($lb <= $val <= $ub)",
),
)
end end
end end
end end

@ -27,11 +27,7 @@ function ProblemData(model::Model)::ProblemData
for (ftype, stype) in list_of_constraint_types(model) for (ftype, stype) in list_of_constraint_types(model)
for constr in all_constraints(model, ftype, stype) for constr in all_constraints(model, ftype, stype)
cset = MOI.get(constr.model.moi_backend, MOI.ConstraintSet(), constr.index) cset = MOI.get(constr.model.moi_backend, MOI.ConstraintSet(), constr.index)
cf = MOI.get( cf = MOI.get(constr.model.moi_backend, MOI.ConstraintFunction(), constr.index)
constr.model.moi_backend,
MOI.ConstraintFunction(),
constr.index,
)
if ftype == VariableRef if ftype == VariableRef
var_idx = cf.value var_idx = cf.value
if stype == MOI.Integer || stype == MOI.ZeroOne if stype == MOI.Integer || stype == MOI.ZeroOne
@ -79,13 +75,7 @@ function ProblemData(model::Model)::ProblemData
n = length(vars) n = length(vars)
m = constr_index - 1 m = constr_index - 1
constr_lhs = sparse( constr_lhs = sparse(constr_lhs_rows, constr_lhs_cols, constr_lhs_values, m, n)
constr_lhs_rows,
constr_lhs_cols,
constr_lhs_values,
m,
n,
)
@assert length(obj) == n @assert length(obj) == n
@assert length(var_lb) == n @assert length(var_lb) == n
@ -96,7 +86,7 @@ function ProblemData(model::Model)::ProblemData
@assert length(constr_ub) == m @assert length(constr_ub) == m
return ProblemData( return ProblemData(
obj_offset=0.0; obj_offset = 0.0;
obj, obj,
constr_lb, constr_lb,
constr_ub, constr_ub,
@ -104,11 +94,11 @@ function ProblemData(model::Model)::ProblemData
var_lb, var_lb,
var_ub, var_ub,
var_types, var_types,
var_names var_names,
) )
end end
function to_model(data::ProblemData, tol=1e-6)::Model function to_model(data::ProblemData, tol = 1e-6)::Model
model = Model() model = Model()
# Variables # Variables
@ -153,7 +143,7 @@ function add_constraint_set(model::JuMP.Model, cs::ConstraintSet)
vars = all_variables(model) vars = all_variables(model)
nrows, _ = size(cs.lhs) nrows, _ = size(cs.lhs)
constrs = [] constrs = []
for i in 1:nrows for i = 1:nrows
c = nothing c = nothing
if isinf(cs.ub[i]) if isinf(cs.ub[i])
c = @constraint(model, cs.lb[i] <= dot(cs.lhs[i, :], vars)) c = @constraint(model, cs.lb[i] <= dot(cs.lhs[i, :], vars))

@ -17,17 +17,17 @@ Base.@kwdef mutable struct ProblemData
end end
Base.@kwdef mutable struct Tableau Base.@kwdef mutable struct Tableau
obj obj::Any
lhs lhs::Any
rhs rhs::Any
z z::Any
end end
Base.@kwdef mutable struct Basis Base.@kwdef mutable struct Basis
var_basic var_basic::Any
var_nonbasic var_nonbasic::Any
constr_basic constr_basic::Any
constr_nonbasic constr_nonbasic::Any
end end
Base.@kwdef mutable struct ConstraintSet Base.@kwdef mutable struct ConstraintSet

@ -56,8 +56,8 @@ function compute_tableau(
data::ProblemData, data::ProblemData,
basis::Basis, basis::Basis,
x::Vector{Float64}; x::Vector{Float64};
rows::Union{Vector{Int},Nothing}=nothing, rows::Union{Vector{Int},Nothing} = nothing,
tol=1e-8 tol = 1e-8,
)::Tableau )::Tableau
@timeit "Split data" begin @timeit "Split data" begin
nrows, ncols = size(data.constr_lhs) nrows, ncols = size(data.constr_lhs)
@ -77,7 +77,7 @@ function compute_tableau(
tableau_lhs_I = Int[] tableau_lhs_I = Int[]
tableau_lhs_J = Int[] tableau_lhs_J = Int[]
tableau_lhs_V = Float64[] tableau_lhs_V = Float64[]
for k in 1:length(rows) for k = 1:length(rows)
@timeit "Prepare inputs" begin @timeit "Prepare inputs" begin
i = rows[k] i = rows[k]
e = zeros(nrows) e = zeros(nrows)
@ -100,13 +100,8 @@ function compute_tableau(
end end
end end
end end
tableau_lhs = sparse( tableau_lhs =
tableau_lhs_I, sparse(tableau_lhs_I, tableau_lhs_J, tableau_lhs_V, length(rows), ncols)
tableau_lhs_J,
tableau_lhs_V,
length(rows),
ncols,
)
end end
@timeit "Compute tableau RHS" begin @timeit "Compute tableau RHS" begin
@ -120,10 +115,10 @@ function compute_tableau(
end end
return Tableau( return Tableau(
obj=tableau_obj, obj = tableau_obj,
lhs=tableau_lhs, lhs = tableau_lhs,
rhs=tableau_rhs, rhs = tableau_rhs,
z=dot(data.obj, x), z = dot(data.obj, x),
) )
end end

@ -13,7 +13,7 @@ function _isbounded(x)
return true return true
end end
function backwards!(transforms::Vector{Transform}, m::ConstraintSet; tol=1e-8) function backwards!(transforms::Vector{Transform}, m::ConstraintSet; tol = 1e-8)
for t in reverse(transforms) for t in reverse(transforms)
backwards!(t, m) backwards!(t, m)
end end
@ -24,7 +24,7 @@ function backwards!(transforms::Vector{Transform}, m::ConstraintSet; tol=1e-8)
end end
end end
function backwards(transforms::Vector{Transform}, m::ConstraintSet; tol=1e-8) function backwards(transforms::Vector{Transform}, m::ConstraintSet; tol = 1e-8)
m2 = deepcopy(m) m2 = deepcopy(m)
backwards!(transforms, m2; tol) backwards!(transforms, m2; tol)
return m2 return m2
@ -68,7 +68,7 @@ Base.@kwdef mutable struct MoveVarUpperBoundsToConstrs <: Transform end
function forward!(t::MoveVarUpperBoundsToConstrs, data::ProblemData) function forward!(t::MoveVarUpperBoundsToConstrs, data::ProblemData)
_, ncols = size(data.constr_lhs) _, ncols = size(data.constr_lhs)
data.constr_lhs = [data.constr_lhs; I] data.constr_lhs = [data.constr_lhs; I]
data.constr_lb = [data.constr_lb; [-Inf for _ in 1:ncols]] data.constr_lb = [data.constr_lb; [-Inf for _ = 1:ncols]]
data.constr_ub = [data.constr_ub; data.var_ub] data.constr_ub = [data.constr_ub; data.var_ub]
data.var_ub .= Inf data.var_ub .= Inf
end end
@ -98,9 +98,9 @@ end
function forward!(t::AddSlackVariables, data::ProblemData) function forward!(t::AddSlackVariables, data::ProblemData)
nrows, ncols = size(data.constr_lhs) nrows, ncols = size(data.constr_lhs)
isequality = abs.(data.constr_ub .- data.constr_lb) .< 1e-6 isequality = abs.(data.constr_ub .- data.constr_lb) .< 1e-6
eq = [i for i in 1:nrows if isequality[i]] eq = [i for i = 1:nrows if isequality[i]]
ge = [i for i in 1:nrows if isfinite(data.constr_lb[i]) && !isequality[i]] ge = [i for i = 1:nrows if isfinite(data.constr_lb[i]) && !isequality[i]]
le = [i for i in 1:nrows if isfinite(data.constr_ub[i]) && !isequality[i]] le = [i for i = 1:nrows if isfinite(data.constr_ub[i]) && !isequality[i]]
EQ, GE, LE = length(eq), length(ge), length(le) EQ, GE, LE = length(eq), length(ge), length(le)
t.M1 = [ t.M1 = [
@ -128,8 +128,8 @@ function forward!(t::AddSlackVariables, data::ProblemData)
data.obj = [data.obj; zeros(GE + LE)] data.obj = [data.obj; zeros(GE + LE)]
data.var_lb = [data.var_lb; zeros(GE + LE)] data.var_lb = [data.var_lb; zeros(GE + LE)]
data.var_ub = [data.var_ub; [Inf for _ = 1:(GE+LE)]] data.var_ub = [data.var_ub; [Inf for _ = 1:(GE+LE)]]
data.var_names = [data.var_names; ["__s$i" for i in 1:(GE+LE)]] data.var_names = [data.var_names; ["__s$i" for i = 1:(GE+LE)]]
data.var_types = [data.var_types; ['C' for _ in 1:(GE+LE)]] data.var_types = [data.var_types; ['C' for _ = 1:(GE+LE)]]
data.constr_lb = [ data.constr_lb = [
data.constr_lb[eq] data.constr_lb[eq]
data.constr_lb[ge] data.constr_lb[ge]
@ -157,15 +157,15 @@ end
Base.@kwdef mutable struct SplitFreeVars <: Transform Base.@kwdef mutable struct SplitFreeVars <: Transform
F::Int = 0 F::Int = 0
B::Int = 0 B::Int = 0
free::Vector{Int}=[] free::Vector{Int} = []
others::Vector{Int}=[] others::Vector{Int} = []
end end
function forward!(t::SplitFreeVars, data::ProblemData) function forward!(t::SplitFreeVars, data::ProblemData)
lhs = data.constr_lhs lhs = data.constr_lhs
_, ncols = size(lhs) _, ncols = size(lhs)
free = [i for i in 1:ncols if !isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i])] free = [i for i = 1:ncols if !isfinite(data.var_lb[i]) && !isfinite(data.var_ub[i])]
others = [i for i in 1:ncols if isfinite(data.var_lb[i]) || isfinite(data.var_ub[i])] others = [i for i = 1:ncols if isfinite(data.var_lb[i]) || isfinite(data.var_ub[i])]
t.F = length(free) t.F = length(free)
t.B = length(others) t.B = length(others)
t.free, t.others = free, others t.free, t.others = free, others
@ -200,7 +200,7 @@ end
function backwards!(t::SplitFreeVars, c::ConstraintSet) function backwards!(t::SplitFreeVars, c::ConstraintSet)
# Convert GE constraints into LE # Convert GE constraints into LE
nrows, _ = size(c.lhs) nrows, _ = size(c.lhs)
ge = [i for i in 1:nrows if isfinite(c.lb[i])] ge = [i for i = 1:nrows if isfinite(c.lb[i])]
c.ub[ge], c.lb[ge] = -c.lb[ge], -c.ub[ge] c.ub[ge], c.lb[ge] = -c.lb[ge], -c.ub[ge]
c.lhs[ge, :] *= -1 c.lhs[ge, :] *= -1
@ -209,8 +209,8 @@ function backwards!(t::SplitFreeVars, c::ConstraintSet)
# Take minimum (weakest) coefficient # Take minimum (weakest) coefficient
B, F = t.B, t.F B, F = t.B, t.F
for i in 1:F for i = 1:F
c.lhs[:, B + i] = min.(c.lhs[:, B + i], -c.lhs[:, B + F + i]) c.lhs[:, B+i] = min.(c.lhs[:, B+i], -c.lhs[:, B+F+i])
end end
c.lhs = c.lhs[:, 1:(B+F)] c.lhs = c.lhs[:, 1:(B+F)]
end end
@ -231,7 +231,7 @@ end
function forward!(t::FlipUnboundedBelowVars, data::ProblemData) function forward!(t::FlipUnboundedBelowVars, data::ProblemData)
_, ncols = size(data.constr_lhs) _, ncols = size(data.constr_lhs)
for i in 1:ncols for i = 1:ncols
if isfinite(data.var_lb[i]) if isfinite(data.var_lb[i])
continue continue
end end

@ -2,6 +2,9 @@
# Copyright (C) 2020-2023, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2023, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
using Printf
using JLD2
global H5File = PyNULL() global H5File = PyNULL()
global write_pkl_gz = PyNULL() global write_pkl_gz = PyNULL()
global read_pkl_gz = PyNULL() global read_pkl_gz = PyNULL()
@ -36,8 +39,27 @@ end
function PyObject(m::SparseMatrixCSC) function PyObject(m::SparseMatrixCSC)
pyimport("scipy.sparse").csc_matrix( pyimport("scipy.sparse").csc_matrix(
(m.nzval, m.rowval .- 1, m.colptr .- 1), (m.nzval, m.rowval .- 1, m.colptr .- 1),
shape = size(m), shape=size(m),
).tocoo() ).tocoo()
end end
export H5File, write_pkl_gz, read_pkl_gz function write_jld2(
objs::Vector,
dirname::AbstractString;
prefix::AbstractString=""
)::Vector{String}
mkpath(dirname)
filenames = [@sprintf("%s/%s%05d.jld2", dirname, prefix, i) for i = 1:length(objs)]
for (i, obj) in enumerate(objs)
jldsave(filenames[i]; obj)
end
return filenames
end
function read_jld2(filename::AbstractString)::Any
jldopen(filename, "r") do file
return file["obj"]
end
end
export H5File, write_pkl_gz, read_pkl_gz, write_jld2, read_jld2

@ -278,8 +278,10 @@ function _set_warm_starts(model::JuMP.Model, var_names, var_values, stats)
n_starts == 1 || error("JuMP does not support multiple warm starts") n_starts == 1 || error("JuMP does not support multiple warm starts")
vars = [variable_by_name(model, v) for v in var_names] vars = [variable_by_name(model, v) for v in var_names]
for (i, var) in enumerate(vars) for (i, var) in enumerate(vars)
if isfinite(var_values[i])
set_start_value(var, var_values[i]) set_start_value(var, var_values[i])
end end
end
end end
function _write(model::JuMP.Model, filename) function _write(model::JuMP.Model, filename)

@ -8,6 +8,7 @@ Clp = "e2554f3b-3117-50c0-817c-e040a3ddf72d"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b" HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572" JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@ -15,3 +16,7 @@ MIPLearn = "2b1277c3-b477-4c49-a15e-7ba350325c68"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[compat]
JuMP = "1"
julia = "1"

Binary file not shown.

@ -3,9 +3,16 @@
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
using MIPLearn using MIPLearn
using JLD2
struct _TestStruct
n::Int
q::Vector{Float64}
end
function test_io() function test_io()
test_pkl_gz() test_pkl_gz()
test_jld2()
test_h5() test_h5()
end end
@ -32,6 +39,26 @@ function test_h5()
h5.close() h5.close()
end end
function test_jld2()
dirname = mktempdir()
data = [
_TestStruct(1, [0.0, 0.0, 0.0]),
_TestStruct(2, [1.0, 2.0, 3.0]),
_TestStruct(3, [3.0, 3.0, 3.0]),
]
filenames = write_jld2(data, dirname, prefix="obj")
@test all(
filenames .==
["$dirname/obj00001.jld2", "$dirname/obj00002.jld2", "$dirname/obj00003.jld2"],
)
@assert isfile("$dirname/obj00001.jld2")
@assert isfile("$dirname/obj00002.jld2")
@assert isfile("$dirname/obj00003.jld2")
recovered = read_jld2("$dirname/obj00002.jld2")
@test recovered.n == 2
@test all(recovered.q .== [1.0, 2.0, 3.0])
end
function _test_roundtrip_scalar(h5, original) function _test_roundtrip_scalar(h5, original)
h5.put_scalar("key", original) h5.put_scalar("key", original)
recovered = h5.get_scalar("key") recovered = h5.get_scalar("key")

Loading…
Cancel
Save