mirror of
https://github.com/ANL-CEEESA/MIPLearn.jl.git
synced 2025-12-06 16:38:51 -06:00
FileInstance: Make interface simpler to use
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
# Released under the modified BSD license. See COPYING.md for more details.
|
||||
|
||||
using JLD2
|
||||
using Distributed
|
||||
using ProgressBars
|
||||
import Base: flush
|
||||
|
||||
mutable struct FileInstance <: Instance
|
||||
@@ -20,8 +22,9 @@ mutable struct FileInstance <: Instance
|
||||
)::FileInstance
|
||||
instance = new(nothing, nothing, filename, nothing, build_model, mode)
|
||||
instance.py = PyFileInstance(instance)
|
||||
if mode != "r" || isfile("$filename.h5")
|
||||
instance.sample = Hdf5Sample("$filename.h5", mode = mode)
|
||||
h5_filename = replace(filename, ".jld2" => ".h5")
|
||||
if mode != "r" || isfile(h5_filename)
|
||||
instance.sample = Hdf5Sample(h5_filename, mode = mode)
|
||||
end
|
||||
instance.filename = filename
|
||||
return instance
|
||||
@@ -107,23 +110,15 @@ function load(filename::AbstractString, build_model::Function)
|
||||
end
|
||||
end
|
||||
|
||||
function save(data::AbstractVector, dirname::String)::Nothing
|
||||
function save(data::AbstractVector, dirname::String)::Vector{String}
|
||||
mkpath(dirname)
|
||||
filenames = []
|
||||
for (i, d) in enumerate(data)
|
||||
filename = joinpath(dirname, @sprintf("%06d.jld2", i))
|
||||
push!(filenames, filename)
|
||||
jldsave(filename, data = d)
|
||||
end
|
||||
end
|
||||
|
||||
function solve!(
|
||||
solver::LearningSolver,
|
||||
filenames::Vector,
|
||||
build_model::Function;
|
||||
tee::Bool = false,
|
||||
)
|
||||
for filename in filenames
|
||||
solve!(solver, filename, build_model; tee)
|
||||
end
|
||||
return filenames
|
||||
end
|
||||
|
||||
function fit!(
|
||||
@@ -136,13 +131,43 @@ function fit!(
|
||||
fit!(solver, instances)
|
||||
end
|
||||
|
||||
function solve!(
|
||||
solver::LearningSolver,
|
||||
filenames::Vector,
|
||||
build_model::Function;
|
||||
tee::Bool = false,
|
||||
progress::Bool = false,
|
||||
)
|
||||
if progress
|
||||
filenames = ProgressBar(filenames)
|
||||
end
|
||||
return [solve!(solver, f, build_model; tee) for f in filenames]
|
||||
end
|
||||
|
||||
function solve!(
|
||||
solver::LearningSolver,
|
||||
filename::AbstractString,
|
||||
build_model::Function;
|
||||
tee::Bool = false,
|
||||
)
|
||||
solve!(solver, FileInstance(filename, build_model); tee)
|
||||
instance = FileInstance(filename, build_model)
|
||||
stats = solve!(solver, instance; tee)
|
||||
instance.sample.file.close()
|
||||
return stats
|
||||
end
|
||||
|
||||
function parallel_solve!(
|
||||
solver::LearningSolver,
|
||||
filenames::Vector,
|
||||
build_model::Function;
|
||||
tee::Bool = false,
|
||||
)
|
||||
solver_filename = tempname()
|
||||
save(solver_filename, solver)
|
||||
@sync @distributed for filename in filenames
|
||||
local_solver = load_solver(solver_filename)
|
||||
solve!(local_solver, filename, build_model; tee)
|
||||
end
|
||||
end
|
||||
|
||||
function __init_PyFileInstance__()
|
||||
|
||||
@@ -4,57 +4,60 @@
|
||||
|
||||
using CSV
|
||||
using DataFrames
|
||||
using OrderedCollections
|
||||
|
||||
function run_benchmarks(;
|
||||
optimizer,
|
||||
train_instances::Vector{<:AbstractString},
|
||||
test_instances::Vector{<:AbstractString},
|
||||
build_model::Function,
|
||||
progress::Bool = false,
|
||||
output_filename::String,
|
||||
)
|
||||
solvers = OrderedDict(
|
||||
"baseline" => LearningSolver(optimizer),
|
||||
"ml-exact" => LearningSolver(optimizer),
|
||||
"ml-heuristic" => LearningSolver(optimizer, mode="heuristic"),
|
||||
)
|
||||
|
||||
mutable struct BenchmarkRunner
|
||||
solvers::Dict
|
||||
results::Union{Nothing,DataFrame}
|
||||
py::PyCall.PyObject
|
||||
#solve!(solvers["baseline"], train_instances, build_model; progress)
|
||||
fit!(solvers["ml-exact"], train_instances, build_model)
|
||||
fit!(solvers["ml-heuristic"], train_instances, build_model)
|
||||
|
||||
function BenchmarkRunner(; solvers::Dict)
|
||||
return new(
|
||||
solvers,
|
||||
nothing, # results
|
||||
miplearn.BenchmarkRunner(
|
||||
Dict(sname => solver.py for (sname, solver) in solvers),
|
||||
),
|
||||
)
|
||||
stats = OrderedDict()
|
||||
for (solver_name, solver) in solvers
|
||||
stats[solver_name] = solve!(solver, test_instances, build_model; progress)
|
||||
end
|
||||
end
|
||||
|
||||
function solve!(
|
||||
runner::BenchmarkRunner,
|
||||
instances::Vector{FileInstance};
|
||||
n_trials::Int = 1,
|
||||
)::Nothing
|
||||
instances = repeat(instances, n_trials)
|
||||
for (solver_name, solver) in runner.solvers
|
||||
@info "benchmark $solver_name"
|
||||
stats = [
|
||||
solve!(solver, instance, discard_output = true, tee = true) for
|
||||
instance in instances
|
||||
]
|
||||
for (i, s) in enumerate(stats)
|
||||
results = nothing
|
||||
for (solver_name, solver_stats) in stats
|
||||
for (i, s) in enumerate(solver_stats)
|
||||
s["Solver"] = solver_name
|
||||
s["Instance"] = instances[i].filename
|
||||
s["Instance"] = test_instances[i]
|
||||
s = Dict(k => isnothing(v) ? missing : v for (k, v) in s)
|
||||
if runner.results === nothing
|
||||
runner.results = DataFrame(s)
|
||||
if results === nothing
|
||||
results = DataFrame(s)
|
||||
else
|
||||
push!(runner.results, s, cols = :union)
|
||||
push!(results, s, cols = :union)
|
||||
end
|
||||
end
|
||||
@info "benchmark $solver_name [done]"
|
||||
end
|
||||
end
|
||||
|
||||
function fit!(runner::BenchmarkRunner, instances::Vector{FileInstance})::Nothing
|
||||
@python_call runner.py.fit([instance.py for instance in instances])
|
||||
end
|
||||
|
||||
function write_csv!(runner::BenchmarkRunner, filename::AbstractString)::Nothing
|
||||
CSV.write(filename, runner.results)
|
||||
CSV.write(output_filename, results)
|
||||
|
||||
# fig_filename = "$(tempname()).svg"
|
||||
# df = pyimport("pandas").read_csv(csv_filename)
|
||||
# miplearn.benchmark.plot(df, output=fig_filename)
|
||||
# open(fig_filename) do f
|
||||
# display("image/svg+xml", read(f, String))
|
||||
# end
|
||||
return
|
||||
end
|
||||
|
||||
function run_benchmarks(; solvers, instance_filenames, build_model, output_filename)
|
||||
runner = BenchmarkRunner(; solvers)
|
||||
instances = [FileInstance(f, build_model) for f in instance_filenames]
|
||||
solve!(runner, instances)
|
||||
write_csv!(runner, output_filename)
|
||||
end
|
||||
|
||||
export BenchmarkRunner, solve!, fit!, write_csv!
|
||||
|
||||
Reference in New Issue
Block a user