Skip to content

Commit

Permalink
progress bars for EnsembleProblem
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Sep 29, 2023
1 parent 5b127c5 commit 63b1299
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Logging
"""
$(TYPEDEF)
"""
Expand Down Expand Up @@ -60,6 +61,13 @@ function __solve(prob::AbstractEnsembleProblem,
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
num_batches * batch_size != trajectories && (num_batches += 1)

if get(kwargs, :progress, false)
name = get(kwargs, :progress_name, "ODE")
for i in 1:trajectories
@logmsg(LogLevel(-1), "$name #$i", _id=Symbol("SciMLBase_$i"), progress=0)
end

Check warning on line 68 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L64-L68

Added lines #L64 - L68 were not covered by tests
end

if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
Expand Down Expand Up @@ -97,7 +105,12 @@ function batch_func(i, prob, alg; kwargs...)
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
new_prob = prob.prob_func(_prob, i, iter)
rerun = true
x = prob.output_func(solve(new_prob, alg; kwargs...), i)

name = get(kwargs, :progress_name, "ODE")
progress_name = "$name #$i"
progress_id = Symbol("SciMLBase_$i")

Check warning on line 111 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L109-L111

Added lines #L109 - L111 were not covered by tests

x = prob.output_func(solve(new_prob, alg; progress_name, progress_id, kwargs...), i)

Check warning on line 113 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L113

Added line #L113 was not covered by tests
if !(typeof(x) <: Tuple)
rerun_warn()
_x = (x, false)
Expand All @@ -109,7 +122,7 @@ function batch_func(i, prob, alg; kwargs...)
iter += 1
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
new_prob = prob.prob_func(_prob, i, iter)
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
x = prob.output_func(solve(new_prob, alg; progress_name, progress_id, kwargs...), i)

Check warning on line 125 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L125

Added line #L125 was not covered by tests
if !(typeof(x) <: Tuple)
rerun_warn()
_x = (x, false)
Expand Down

0 comments on commit 63b1299

Please sign in to comment.