From 1ecae9342f3264529e7b131063f00a985ae00f73 Mon Sep 17 00:00:00 2001 From: Pepijn de Vos Date: Fri, 29 Sep 2023 14:44:56 +0200 Subject: [PATCH 1/3] progress bars for EnsembleProblem ODE->Ensemble optionally aggregate progress bars handle integer loglevel (sundials) avoid lock contention only log significant progress improve progress performance add version constraint Update Project.toml Update Project.toml ignore derivatives of logging more AD fixing attempts only pass progress_id if needed use ignore_deriviative and fix rrule for with_logger delete rules moved to ChainRules remove using import as opposed to using more import fixes more missing Logging qualifiers --- src/ensemble/basic_ensemble_solve.jl | 138 +++++++++++++++++++++------ 1 file changed, 108 insertions(+), 30 deletions(-) diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index db08eb9a8..97cfa88b9 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -29,6 +29,64 @@ function merge_stats(us) reduce(merge, st) end +mutable struct AggregateLogger{T<:Logging.AbstractLogger} <: Logging.AbstractLogger + progress::Dict{Symbol, Float64} + done_counter::Int + total::Float64 + print_time::Float64 + lock::ReentrantLock + logger::T +end +AggregateLogger(logger::Logging.AbstractLogger) = AggregateLogger(Dict{Symbol, Float64}(),0 , 0.0, 0.0, ReentrantLock(), logger) + +function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id, file, line; kwargs...) + if convert(Logging.LogLevel, level) == Logging.LogLevel(-1) && haskey(kwargs, :progress) + pr = kwargs[:progress] + if trylock(l.lock) || (pr == "done" && lock(l.lock)===nothing) + try + if pr == "done" + pr = 1.0 + l.done_counter += 1 + end + len = length(l.progress) + if haskey(l.progress, id) + l.total += (pr-l.progress[id])/len + else + l.total = l.total*(len/(len+1)) + pr/(len+1) + len += 1 + end + l.progress[id] = pr + # validation check (slow) + # tot = sum(values(l.progress))/length(l.progress) + # @show tot l.total l.total ≈ tot + curr_time = time() + if l.done_counter >= len + tot="done" + empty!(l.progress) + l.done_counter = 0 + l.print_time = 0.0 + elseif curr_time-l.print_time > 0.1 + tot = l.total + l.print_time = curr_time + else + return + end + id=:total + message="Total" + kwargs=merge(values(kwargs), (progress=tot,)) + finally + unlock(l.lock) + end + else + return + end + end + Logging.handle_message(l.logger, level, message, _module, group, id, file, line; kwargs...) +end +Logging.shouldlog(l::AggregateLogger, args...) = Logging.shouldlog(l.logger, args...) +Logging.min_enabled_level(l::AggregateLogger) = Logging.min_enabled_level(l.logger) +Logging.catch_exceptions(l::AggregateLogger) = Logging.catch_exceptions(l.logger) + function __solve(prob::AbstractEnsembleProblem, alg::Union{AbstractDEAlgorithm, Nothing}; kwargs...) @@ -59,44 +117,56 @@ end function __solve(prob::AbstractEnsembleProblem, alg::A, ensemblealg::BasicEnsembleAlgorithm; - trajectories, batch_size = trajectories, + trajectories, batch_size = trajectories, progress_aggregate=true, pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A} - num_batches = trajectories ÷ batch_size - num_batches < 1 && - error("trajectories ÷ batch_size cannot be less than 1, got $num_batches") - num_batches * batch_size != trajectories && (num_batches += 1) - - if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION - elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories, - pmap_batch_size; kwargs...) - _u = tighten_container_eltype(u) - stats = merge_stats(_u) - return EnsembleSolution(_u, elapsed_time, true, stats) - end + logger = progress_aggregate ? AggregateLogger(Logging.current_logger()) : Logging.current_logger() + + Logging.with_logger(logger) do + num_batches = trajectories ÷ batch_size + num_batches < 1 && + error("trajectories ÷ batch_size cannot be less than 1, got $num_batches") + num_batches * batch_size != trajectories && (num_batches += 1) - converged::Bool = false - elapsed_time = @elapsed begin - i = 1 - II = (batch_size * (i - 1) + 1):(batch_size * i) + if get(kwargs, :progress, false) + name = get(kwargs, :progress_name, "Ensemble") + for i in 1:trajectories + msg = "$name #$i" + Logging.@logmsg(Logging.LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0) + end + end + - batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...) + if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION + elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories, + pmap_batch_size; kwargs...) + _u = tighten_container_eltype(u) + return EnsembleSolution(_u, elapsed_time, true) + end + + converged::Bool = false + elapsed_time = @elapsed begin + i = 1 + II = (batch_size * (i - 1) + 1):(batch_size * i) - u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init - u, converged = prob.reduction(u, batch_data, II) - for i in 2:num_batches - converged && break - if i == num_batches - II = (batch_size * (i - 1) + 1):trajectories - else - II = (batch_size * (i - 1) + 1):(batch_size * i) - end batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...) + + u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init u, converged = prob.reduction(u, batch_data, II) + for i in 2:num_batches + converged && break + if i == num_batches + II = (batch_size * (i - 1) + 1):trajectories + else + II = (batch_size * (i - 1) + 1):(batch_size * i) + end + batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...) + u, converged = prob.reduction(u, batch_data, II) + end end + _u = tighten_container_eltype(u) + stats = merge_stats(_u) + return EnsembleSolution(_u, elapsed_time, converged, stats) end - _u = tighten_container_eltype(u) - stats = merge_stats(_u) - return EnsembleSolution(_u, elapsed_time, converged, stats) end function batch_func(i, prob, alg; kwargs...) @@ -104,6 +174,14 @@ function batch_func(i, prob, alg; kwargs...) _prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob new_prob = prob.prob_func(_prob, i, iter) rerun = true + + progress = get(kwargs, :progress, false) + if progress + name = get(kwargs, :progress_name, "Ensemble") + progress_name = "$name #$i" + progress_id = Symbol("SciMLBase_$i") + kwargs = (kwargs..., progress_name=progress_name, progress_id=progress_id) + end x = prob.output_func(solve(new_prob, alg; kwargs...), i) if !(x isa Tuple) rerun_warn() From 1d44a6692fe4179ffc35b0b168a99c4059032345 Mon Sep 17 00:00:00 2001 From: Pepijn de Vos Date: Tue, 7 Nov 2023 14:41:32 +0100 Subject: [PATCH 2/3] add chainrules version requirement --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 6d1becd12..8d9b53c08 100644 --- a/Project.toml +++ b/Project.toml @@ -33,6 +33,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote" [compat] ADTypes = "0.1.3, 0.2" ArrayInterface = "6, 7" +ChainRules = "1.57.0" ChainRulesCore = "1.16" CommonSolve = "0.2.4" ConstructionBase = "1" From a1370a0bf84d1f5ec1e655b2f0b37228f423773d Mon Sep 17 00:00:00 2001 From: Pepijn de Vos Date: Wed, 8 Nov 2023 12:55:22 +0100 Subject: [PATCH 3/3] put stats back This most likely got lost in a rebase --- src/ensemble/basic_ensemble_solve.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index 97cfa88b9..ece01eb68 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -140,7 +140,8 @@ function __solve(prob::AbstractEnsembleProblem, elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories, pmap_batch_size; kwargs...) _u = tighten_container_eltype(u) - return EnsembleSolution(_u, elapsed_time, true) + stats = merge_stats(_u) + return EnsembleSolution(_u, elapsed_time, true, stats) end converged::Bool = false