Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add stats to ensemble solution #512

Merged
merged 14 commits into from
Oct 28, 2023
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
24 changes: 23 additions & 1 deletion ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

using Zygote: pullback
using ZygoteRules: @adjoint
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
import ZygoteRules
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -55,4 +56,25 @@
VA[sym, j], ODESolution_getindex_pullback
end

ZygoteRules.@adjoint function EnsembleSolution(sim, time, converged, stats)
out = EnsembleSolution(sim, time, converged)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]

Check warning on line 62 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L59-L62

Added lines #L59 - L62 were not covered by tests
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing, nothing)

Check warning on line 64 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L64

Added line #L64 was not covered by tests
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
(EnsembleSolution(p̄, 0.0, true), nothing, nothing, nothing)

Check warning on line 67 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L66-L67

Added lines #L66 - L67 were not covered by tests
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(p̄, nothing, nothing, nothing)

Check warning on line 70 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
end
out, EnsembleSolution_adjoint

Check warning on line 72 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L72

Added line #L72 was not covered by tests
end

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,

Check warning on line 75 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L75

Added line #L75 was not covered by tests
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)

Check warning on line 77 in ext/SciMLBaseZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/SciMLBaseZygoteExt.jl#L77

Added line #L77 was not covered by tests
end

end
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using LinearAlgebra
using Statistics
using Distributed
using Markdown
using Printf
import Preferences

import Logging, ArrayInterface
Expand Down
13 changes: 10 additions & 3 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ $(TYPEDEF)
"""
struct EnsembleSerial <: BasicEnsembleAlgorithm end

function merge_stats(us)
st = Iterators.filter(!isnothing, (hasproperty(x, :stats) ? x.stats : nothing for x in us))
isempty(st) && return nothing
reduce(merge, st)
end

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing};
kwargs...)
Expand Down Expand Up @@ -64,7 +70,8 @@ function __solve(prob::AbstractEnsembleProblem,
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
return EnsembleSolution(_u, elapsed_time, true)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end

converged::Bool = false
Expand All @@ -88,8 +95,8 @@ function __solve(prob::AbstractEnsembleProblem,
end
end
_u = tighten_container_eltype(u)

return EnsembleSolution(_u, elapsed_time, converged)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end

function batch_func(i, prob, alg; kwargs...)
Expand Down
20 changes: 11 additions & 9 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,23 @@
u::S
elapsedTime::Float64
converged::Bool
stats
end
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged) where {N}
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim)}(sim, elapsedTime, converged)
function EnsembleSolution(sim, dims::NTuple{N}, elapsedTime, converged, stats) where {N}
EnsembleSolution{eltype(eltype(sim)), N, typeof(sim)}(sim, elapsedTime, converged, stats)
end
function EnsembleSolution(sim, elapsedTime, converged)
EnsembleSolution(sim, (length(sim),), elapsedTime, converged)
function EnsembleSolution(sim, elapsedTime, converged, stats=nothing)
EnsembleSolution(sim, (length(sim),), elapsedTime, converged, stats)
end # Vector of some type which is not an array
function EnsembleSolution(sim::T, elapsedTime,
converged) where {T <: AbstractVector{T2}
converged, stats=nothing) where {T <: AbstractVector{T2}
} where {T2 <:
AbstractArray}
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1,
typeof(sim)}(sim,
EnsembleSolution{eltype(eltype(sim)), ndims(sim[1]) + 1, typeof(sim)}(
sim,
elapsedTime,
converged)
converged,
stats)
end

struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
Expand All @@ -56,7 +58,7 @@
end

function Base.reverse(sim::EnsembleSolution)
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged)
EnsembleSolution(reverse(sim.u), sim.elapsedTime, sim.converged, sim.stats)

Check warning on line 61 in src/ensemble/ensemble_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/ensemble/ensemble_solutions.jl#L61

Added line #L61 was not covered by tests
end

"""
Expand Down
14 changes: 14 additions & 0 deletions src/solutions/nonlinear_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@
nsteps::Int
end

function Base.show(io::IO, ::MIME"text/plain", s::NLStats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of factorizations:" s.nfactors
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d" "Number of nonlinear solver iterations:" s.nsteps

Check warning on line 26 in src/solutions/nonlinear_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/nonlinear_solutions.jl#L20-L26

Added lines #L20 - L26 were not covered by tests
end

function Base.merge(s1::NLStats, s2::NLStats)
NLStats(s1.nf + s2.nf, s1.njacs + s2.njacs, s1.nfactors + s2.nfactors,

Check warning on line 30 in src/solutions/nonlinear_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/nonlinear_solutions.jl#L29-L30

Added lines #L29 - L30 were not covered by tests
s1.nsolve + s2.nsolve, s1.nsteps + s2.nsteps)
end

"""
$(TYPEDEF)

Expand Down
73 changes: 73 additions & 0 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,79 @@
"""
$(TYPEDEF)

Statistics from the differential equation solver about the solution process.

## Fields

- nf: Number of function evaluations. If the differential equation is a split function,
such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the
number of function evaluations for the first function (the implicit function)
- nf2: If the differential equation is a split function, such as a `SplitFunction`
for implicit-explicit (IMEX) integration, then `nf2` is the number of function
evaluations for the second function, i.e. the function treated explicitly. Otherwise
it is zero.
- nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving
process.
- nsolve: The number of linear solves `W\b` required for the integration.
- njacs: Number of Jacobians calculated during the integration.
- nnonliniter: Total number of iterations for the nonlinear solvers.
- nnonlinconvfail: Number of nonlinear solver convergence failures.
- ncondition: Number of calls to the condition function for callbacks.
- naccept: Number of accepted steps.
- nreject: Number of rejected steps.
- maxeig: Maximum eigenvalue over the solution. This is only computed if the
method is an auto-switching algorithm.
"""
mutable struct DEStats
nf::Int
nf2::Int
nw::Int
nsolve::Int
njacs::Int
nnonliniter::Int
nnonlinconvfail::Int
ncondition::Int
naccept::Int
nreject::Int
maxeig::Float64
end

DEStats(x::Int = -1) = DEStats(x, x, x, x, x, x, x, x, x, x, 0.0)

function Base.show(io::IO, ::MIME"text/plain", s::DEStats)
println(io, summary(s))
@printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf
@printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2
@printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw
@printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve
@printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs
@printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter
@printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail
@printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition
@printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept
@printf io "%-50s %-d" "Number of rejected steps:" s.nreject
iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig

Check warning on line 55 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L43-L55

Added lines #L43 - L55 were not covered by tests
end

function Base.merge(a::DEStats, b::DEStats)
DEStats(
a.nf + b.nf,
a.nf2 + b.nf2,
a.nw + b.nw,
a.nsolve + b.nsolve,
a.njacs + b.njacs,
a.nnonliniter + b.nnonliniter,
a.nnonlinconvfail + b.nnonlinconvfail,
a.ncondition + b.ncondition,
a.naccept + b.naccept,
a.nreject + b.nreject,
max(a.maxeig, b.maxeig),
)
end

"""
$(TYPEDEF)

Representation of the solution to an ordinary differential equation defined by an ODEProblem.

## DESolution Interface
Expand Down
13 changes: 13 additions & 0 deletions test/downstream/ensemble_stats.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using OrdinaryDiffEq
using Test

f(u,p,t) = 1.01*u
u0=1/2
tspan = (0.0,1.0)
prob = ODEProblem(f,u0,tspan)
function prob_func(prob, i, repeat)
remake(prob, u0 = rand() * prob.u0)
end
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
@test sim.stats.nf == mapreduce(x -> x.stats.nf, +, sim.u)
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ end
@time @safetestset "solving Ensembles with multiple problems" begin
include("downstream/ensemble_multi_prob.jl")
end
@time @safetestset "Ensemble solution statistics" begin
include("downstream/ensemble_stats.jl")
end
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
include("downstream/symbol_indexing.jl")
end
Expand Down
Loading