Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add stats to ensemble solution #512

Merged
merged 14 commits into from
Oct 28, 2023
21 changes: 18 additions & 3 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
"""
struct EnsembleSerial <: BasicEnsembleAlgorithm end

function merge_stats(us)
st = Iterators.filter(!isnothing, (x.stats for x in us))
try
return reduce(merge, st)

Check warning on line 29 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L26-L29

Added lines #L26 - L29 were not covered by tests
catch e
if isa(e, MethodError)

Check warning on line 31 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L31

Added line #L31 was not covered by tests
# there were no stats or they didn't have a merge method
return nothing

Check warning on line 33 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L33

Added line #L33 was not covered by tests
else
rethrow(e)

Check warning on line 35 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L35

Added line #L35 was not covered by tests
end
end
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
end

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing};
kwargs...)
Expand Down Expand Up @@ -64,7 +78,8 @@
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
return EnsembleSolution(_u, elapsed_time, true)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)

Check warning on line 82 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L81-L82

Added lines #L81 - L82 were not covered by tests
end

converged::Bool = false
Expand All @@ -88,8 +103,8 @@
end
end
_u = tighten_container_eltype(u)

return EnsembleSolution(_u, elapsed_time, converged)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)

Check warning on line 107 in src/ensemble/basic_ensemble_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/basic_ensemble_solve.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
end

function batch_func(i, prob, alg; kwargs...)
Expand Down
20 changes: 11 additions & 9 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,27 @@
"""
$(TYPEDEF)
"""
struct EnsembleSolution{T, N, S} <: AbstractEnsembleSolution{T, N, S}
struct EnsembleSolution{T, N, S, U} <: AbstractEnsembleSolution{T, N, S}
u::S
elapsedTime::Float64
converged::Bool
stats::U
end
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged) where {N}
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim)}(sim, elapsedTime, converged)
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged, stats) where {N}
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim), typeof(stats)}(sim, elapsedTime, converged, stats)

Check warning on line 35 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
end
function EnsembleSolution(sim, elapsedTime, converged)
EnsembleSolution(sim, (length(sim),), elapsedTime, converged)
function EnsembleSolution(sim, elapsedTime, converged, stats)
EnsembleSolution(sim, (length(sim),), elapsedTime, converged, stats)

Check warning on line 38 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L37-L38

Added lines #L37 - L38 were not covered by tests
end # Vector of some type which is not an array
function EnsembleSolution(sim::T, elapsedTime,
converged) where {T <: AbstractVector{T2}
converged, stats) where {T <: AbstractVector{T2}
} where {T2 <:
AbstractArray}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,
typeof(sim)}(sim,
typeof(sim), typeof(stats)}(sim,
elapsedTime,
converged)
converged,
stats)
end

struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
Expand All @@ -56,7 +58,7 @@
end

function Base.reverse(sim::EnsembleSolution)
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged)
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged, sim.stats)

Check warning on line 61 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L61

Added line #L61 was not covered by tests
end

"""
Expand Down
14 changes: 14 additions & 0 deletions src/solutions/nonlinear_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@
nsteps::Int
end

function Base.show(io::IO, ::MIME"text/plain", s::NLStats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of factorizations:" s.nfactors
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d" "Number of nonlinear solver iterations:" s.nsteps

Check warning on line 26 in src/solutions/nonlinear_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/nonlinear_solutions.jl#L20-L26

Added lines #L20 - L26 were not covered by tests
end

function Base.merge(s1::NLStats, s2::NLStats)
NLStats(s1.nf + s2.nf, s1.njacs + s2.njacs, s1.nfactors + s2.nfactors,

Check warning on line 30 in src/solutions/nonlinear_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/nonlinear_solutions.jl#L29-L30

Added lines #L29 - L30 were not covered by tests
s1.nsolve + s2.nsolve, s1.nsteps + s2.nsteps)
end

"""
$(TYPEDEF)

Expand Down
73 changes: 73 additions & 0 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,79 @@
"""
$(TYPEDEF)

Statistics from the differential equation solver about the solution process.

## Fields

- nf: Number of function evaluations. If the differential equation is a split function,
such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the
number of function evaluations for the first function (the implicit function)
- nf2: If the differential equation is a split function, such as a `SplitFunction`
for implicit-explicit (IMEX) integration, then `nf2` is the number of function
evaluations for the second function, i.e. the function treated explicitly. Otherwise
it is zero.
- nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving
process.
- nsolve: The number of linear solves `W\b` required for the integration.
- njacs: Number of Jacobians calculated during the integration.
- nnonliniter: Total number of iterations for the nonlinear solvers.
- nnonlinconvfail: Number of nonlinear solver convergence failures.
- ncondition: Number of calls to the condition function for callbacks.
- naccept: Number of accepted steps.
- nreject: Number of rejected steps.
- maxeig: Maximum eigenvalue over the solution. This is only computed if the
method is an auto-switching algorithm.
"""
mutable struct DEStats
nf::Int
nf2::Int
nw::Int
nsolve::Int
njacs::Int
nnonliniter::Int
nnonlinconvfail::Int
ncondition::Int
naccept::Int
nreject::Int
maxeig::Float64
end

DEStats(x::Int = -1) = DEStats(x, x, x, x, x, x, x, x, x, x, 0.0)

Check warning on line 41 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L41

Added line #L41 was not covered by tests

function Base.show(io::IO, ::MIME"text/plain", s::DEStats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2
@printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter
@printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail
@printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition
@printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept
@printf io "%-50s %-d" "Number of rejected steps:" s.nreject
iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig

Check warning on line 55 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L43-L55

Added lines #L43 - L55 were not covered by tests
end

function Base.merge(a::DEStats, b::DEStats)
DEStats(

Check warning on line 59 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L58-L59

Added lines #L58 - L59 were not covered by tests
a.nf + b.nf,
a.nf2 + b.nf2,
a.nw + b.nw,
a.nsolve + b.nsolve,
a.njacs + b.njacs,
a.nnonliniter + b.nnonliniter,
a.nnonlinconvfail + b.nnonlinconvfail,
a.ncondition + b.ncondition,
a.naccept + b.naccept,
a.nreject + b.nreject,
max(a.maxeig, b.maxeig),
)
end

"""
$(TYPEDEF)

Representation of the solution to an ordinary differential equation defined by an ODEProblem.

## DESolution Interface
Expand Down
13 changes: 13 additions & 0 deletions test/downstream/ensemble_stats.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using OrdinaryDiffEq
using Test

f(u,p,t) = 1.01*u
u0=1/2
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)
function prob_func(prob, i, repeat)
remake(prob, u0 = rand() * prob.u0)
end
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
@test sim.stats.nf == mapreduce(x -> x.stats.nf, +, sim.u)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ end
@time @safetestset "solving Ensembles with multiple problems" begin
include("downstream/ensemble_multi_prob.jl")
end
@time @safetestset "Ensemble solution statistics" begin
include("downstream/ensemble_stats.jl")
end
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
Expand Down
Loading