Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

progress bars for EnsembleProblem #514

Merged
merged 3 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand All @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote"
[compat]
ADTypes = "0.1.3, 0.2"
ArrayInterface = "6, 7"
ChainRules = "1.57.0"
ChainRulesCore = "1.16"
CommonSolve = "0.2.4"
ConstructionBase = "1"
Expand Down
139 changes: 109 additions & 30 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,64 @@ function merge_stats(us)
reduce(merge, st)
end

"""
    AggregateLogger(logger::Logging.AbstractLogger)

Logger that wraps `logger` and keeps the bookkeeping needed to condense many
individual progress records into one aggregated "Total" progress record.

# Fields
- `progress`: most recent progress value recorded for each log id.
- `done_counter`: number of `"done"` progress records counted so far.
- `total`: running mean of the values stored in `progress`.
- `print_time`: `time()` stamp of the last aggregated record that was forwarded.
- `lock`: guards the mutable bookkeeping fields above.
- `logger`: the wrapped logger that ultimately receives every forwarded record.
"""
mutable struct AggregateLogger{T <: Logging.AbstractLogger} <: Logging.AbstractLogger
    progress::Dict{Symbol, Float64}
    done_counter::Int
    total::Float64
    print_time::Float64
    lock::ReentrantLock
    logger::T
end

# Convenience constructor: start with empty bookkeeping around `logger`.
function AggregateLogger(logger::Logging.AbstractLogger)
    return AggregateLogger(
        Dict{Symbol, Float64}(),  # progress
        0,                        # done_counter
        0.0,                      # total
        0.0,                      # print_time
        ReentrantLock(),          # lock
        logger)
end

"""
    Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
                           file, line; kwargs...)

Forward a log record to the wrapped logger `l.logger`.

Records at `LogLevel(-1)` that carry a `progress` keyword are not forwarded
individually. Their value is folded into a running mean over all ids seen so
far, and a single record with message `"Total"` and id `:total` is forwarded
instead — at most once every 0.1 seconds, or immediately with
`progress = "done"` once the number of `"done"` records reaches the number of
tracked ids (after which the bookkeeping is reset). Every other record is
passed through unchanged.

Returns whatever the wrapped logger's `handle_message` returns, or `nothing`
when the record is dropped.
"""
function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
        file, line; kwargs...)
    is_progress = convert(Logging.LogLevel, level) == Logging.LogLevel(-1) &&
                  haskey(kwargs, :progress)
    if !is_progress
        # Ordinary record: hand it to the wrapped logger untouched.
        return Logging.handle_message(l.logger, level, message, _module, group, id,
            file, line; kwargs...)
    end

    pr = kwargs[:progress]
    if !trylock(l.lock)
        # The lock is busy. An intermediate update is simply dropped, but a
        # "done" record has to be counted, so block until the lock is free.
        pr == "done" || return
        lock(l.lock)
    end

    aggregate = try
        finished = pr == "done"
        value = finished ? 1.0 : pr
        finished && (l.done_counter += 1)

        # Keep `l.total` equal to the mean of `values(l.progress)` (up to
        # floating-point rounding) without re-summing the dictionary.
        n = length(l.progress)
        if haskey(l.progress, id)
            l.total += (value - l.progress[id]) / n
        else
            l.total = l.total * (n / (n + 1)) + value / (n + 1)
            n += 1
        end
        l.progress[id] = value

        t_now = time()
        if l.done_counter >= n
            # Every tracked id has reported "done": reset for the next run.
            empty!(l.progress)
            l.done_counter = 0
            l.print_time = 0.0
            "done"
        elseif t_now - l.print_time > 0.1
            l.print_time = t_now
            l.total
        else
            # Throttled: the last aggregate was forwarded too recently.
            return
        end
    finally
        unlock(l.lock)
    end

    # Forward the aggregate after the lock has been released.
    return Logging.handle_message(l.logger, level, "Total", _module, group, :total,
        file, line; merge(values(kwargs), (progress = aggregate,))...)
end
# The remaining logger-interface queries are answered by the wrapped logger.
function Logging.shouldlog(l::AggregateLogger, args...)
    return Logging.shouldlog(l.logger, args...)
end

function Logging.min_enabled_level(l::AggregateLogger)
    return Logging.min_enabled_level(l.logger)
end

function Logging.catch_exceptions(l::AggregateLogger)
    return Logging.catch_exceptions(l.logger)
end

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing};
kwargs...)
Expand Down Expand Up @@ -59,51 +117,72 @@ end
function __solve(prob::AbstractEnsembleProblem,
alg::A,
ensemblealg::BasicEnsembleAlgorithm;
trajectories, batch_size = trajectories,
trajectories, batch_size = trajectories, progress_aggregate=true,
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A}
num_batches = trajectories ÷ batch_size
num_batches < 1 &&
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
num_batches * batch_size != trajectories && (num_batches += 1)

if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end
logger = progress_aggregate ? AggregateLogger(Logging.current_logger()) : Logging.current_logger()

Logging.with_logger(logger) do
num_batches = trajectories ÷ batch_size
num_batches < 1 &&
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
num_batches * batch_size != trajectories && (num_batches += 1)

converged::Bool = false
elapsed_time = @elapsed begin
i = 1
II = (batch_size * (i - 1) + 1):(batch_size * i)
if get(kwargs, :progress, false)
name = get(kwargs, :progress_name, "Ensemble")
for i in 1:trajectories
msg = "$name #$i"
Logging.@logmsg(Logging.LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0)
end
end


batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end

converged::Bool = false
elapsed_time = @elapsed begin
i = 1
II = (batch_size * (i - 1) + 1):(batch_size * i)

u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
u, converged = prob.reduction(u, batch_data, II)
for i in 2:num_batches
converged && break
if i == num_batches
II = (batch_size * (i - 1) + 1):trajectories
else
II = (batch_size * (i - 1) + 1):(batch_size * i)
end
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)

u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
u, converged = prob.reduction(u, batch_data, II)
for i in 2:num_batches
converged && break
if i == num_batches
II = (batch_size * (i - 1) + 1):trajectories
else
II = (batch_size * (i - 1) + 1):(batch_size * i)
end
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
u, converged = prob.reduction(u, batch_data, II)
end
end
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end

function batch_func(i, prob, alg; kwargs...)
iter = 1
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
new_prob = prob.prob_func(_prob, i, iter)
rerun = true

progress = get(kwargs, :progress, false)
if progress
name = get(kwargs, :progress_name, "Ensemble")
progress_name = "$name #$i"
progress_id = Symbol("SciMLBase_$i")
kwargs = (kwargs..., progress_name=progress_name, progress_id=progress_id)
end
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
if !(x isa Tuple)
rerun_warn()
Expand Down