Skip to content

Commit

Permalink
Refactor the plotting in favor of a Gaussian plot recipe (#291)
Browse files Browse the repository at this point in the history
* Refactor the plotting in favor of a Gaussian plot recipe

Previously we had a recipe for the specific solution type that we implemented.
But for some reason, this broke du to sanges of SciML. So to not have such
issues again, let's just use their plot recipe! To do this we just needed to
implement some recipes for the `Gaussian` type that we use in the solutions, and
make sure that the solution interpolation etc is doing what it should do.

* Towards using the `sol.interp` argument for interpolation

* Add PNStats object and make the solution immutable

* Respect the `idxs` interpolation argument to fix the tests

* Make clear that we don't support plotting derivatives right now

Though we might want to add this in this PR!!

* JuliaFormatter.jl

* Add proper types to the posterior object

* Fix the mean solution again

* Add Printf compat

* Fix an unexpected secondorderode bug

* Solve another classic solver init problem

* Simplify the prior plots yet again

* Remove inefficient `push!` loop in favor of `append!`

* Add the mean solution tests again

* Add docstring to PNStats

* Fix a bug

* The solution interpolation can return derivatives now!

* For some reason I needed this to fix the plotting

* Fix the plots once again

* Test the derivative interpolation

* Update the implementation docs a bit

* Add compat entry for DiffEqDevTools to the test env
  • Loading branch information
nathanaelbosch authored Feb 3, 2024
1 parent 815af08 commit 6935a52
Show file tree
Hide file tree
Showing 16 changed files with 202 additions and 118 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PSDMatrices = "fe68d972-6fd8-4755-bdf0-97d4c54cefdc"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -62,6 +63,7 @@ Octavian = "0.3.17"
OrdinaryDiffEq = "6.52"
PSDMatrices = "0.4.7"
PrecompileTools = "1"
Printf = "1"
RecipesBase = "1"
RecursiveArrayTools = "2, 3"
Reexport = "1"
Expand Down
6 changes: 3 additions & 3 deletions docs/src/implementation.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ ProbNumDiffEq.jl builds around this structure and overloads some of the parts:
- `./src/solution.jl` implements the main parts.
Note that the main constructor `DiffEqBase.build_solution` is called by `OrdinaryDiffEq.__init`, so OrdinaryDiffEq.jl has control over its inputs.
- `MeanProbODESolution <: DiffEqBase.AbstractODESolution` is a wrapper that allows handling the mean of a probabilistic ODE solution the same way one would handle any "standard" ODE solution, by just ignoring the covariances.
- `AbstractODEFilterPosterior <: DiffEqBase.AbstractDiffEqInterpolation` is the current interpolant, but it does not actually fully handle the interpolation right now. This part might be subject to change soon.
- *Plot recipe* in `./src/solution_plotting.jl`
- `AbstractODEFilterPosterior <: DiffEqBase.AbstractDiffEqInterpolation` handles the interpolation.
- *Plot recipe* in `./ext/RecipesBaseExt.jl`
- *Sampling* in `./src/solution_sampling.jl`
- `DiffEqBase.prepare_alg(::EK1{0})`; closely follows a similar function implemented in OrdinaryDiffEq.jl `./src/alg_utils.jl`
- this also required `DiffEqBase.remake(::EK1)`

## Other packages
- `DiffEqDevTools.appxtrue`: We overload this function to work with `ProbODESolution`. This also enables `DiffEqDevTools.WorkPrecision` to work out of the box.
- `DiffEqDevTools.appxtrue`: We extend this function to work with `ProbODESolution`. This also enables `DiffEqDevTools.WorkPrecision` to work out of the box.
127 changes: 55 additions & 72 deletions ext/RecipesBaseExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,73 +3,63 @@ module RecipesBaseExt
using RecipesBase
using ProbNumDiffEq
using Statistics
import ProbNumDiffEq: AbstractProbODESolution
import SciMLBase: interpret_vars, getsyms

@recipe function f(
sol::AbstractProbODESolution;
idxs=nothing,
denseplot=sol.dense,
plotdensity=1000,
tspan=nothing,
p::AbstractArray{<:Gaussian};
ribbon_width=1.96,
vars=nothing,
)
if vars !== nothing
Base.depwarn(
"To maintain consistency with solution indexing, keyword argument `vars` will be removed in a future version. Please use keyword argument `idxs` instead.",
:f; force=true)
(idxs !== nothing) &&
error(
"Simultaneously using keywords `vars` and `idxs` is not supported. Please only use idxs.",
)
idxs = vars
end

tstart, tend = isnothing(tspan) ? (sol.t[1], sol.t[end]) : tspan
times = denseplot ? range(tstart, tend, length=plotdensity) : sol.t
sol_rvs = denseplot ? sol(times).u : sol.pu
if !isnothing(tspan)
sol_rvs = sol_rvs[tstart.<=times.<=tend]
times = times[tstart.<=times.<=tend]
end
values = stack(mean.(sol_rvs))'
stds = stack(std.(sol_rvs))'
means = mean.(p) |> stack |> permutedims
stddevs = std.(p) |> stack |> permutedims
ribbon --> ribbon_width * stddevs
return means
end
@recipe function f(
x, y::AbstractArray{<:Gaussian};
ribbon_width=1.96,
)
means = mean.(y) |> stack |> permutedims
stddevs = std.(y) |> stack |> permutedims
ribbon --> ribbon_width * stddevs
return x, means
end

if isnothing(idxs)
ribbon --> ribbon_width * stds
xguide --> "t"
yguide --> "u(t)"
label --> hcat(["u$(i)(t)" for i in 1:length(sol.u[1])]...)
return times, values
else
int_vars = interpret_vars(idxs, sol, getsyms(sol))
varsizes = unique(length.(int_vars))
@assert length(varsizes) == 1 "`idxs` argument has an errors"
ndims = varsizes[1] - 1 # First argument is not about dimensions
if ndims == 2
_x = []
_y = []
_labels = []
for (_, i, j) in int_vars
push!(_x, i == 0 ? times : values[:, i])
push!(_y, j == 0 ? times : values[:, j])
end
return _x, _y
elseif ndims == 3
_x = []
_y = []
_z = []
for (_, i, j, k) in int_vars
push!(_x, i == 0 ? times : values[:, i])
push!(_y, j == 0 ? times : values[:, j])
push!(_z, k == 0 ? times : values[:, k])
end
return _x, _y, _z
else
error("Error with `idxs` argument")
end
end
@recipe function f(
x, y::Matrix{<:Gaussian};
ribbon_width=1.96,
)
means = mean.(y)
stddevs = std.(y)
ribbon --> ribbon_width * stddevs
return x, means
end
@recipe function f(
x::Matrix{<:Gaussian}, y::Matrix{<:Gaussian},
)
@warn "This plot does not visualize any uncertainties"
xmeans = mean.(x)
ymeans = mean.(y)
return xmeans, ymeans
end
@recipe function f(
x,
y::Matrix{<:Gaussian},
z::Matrix{<:Gaussian},
)
@warn "This plot does not visualize any uncertainties"
ymeans = mean.(y)
zmeans = mean.(z)
return x, ymeans, zmeans
end
@recipe function f(
x::Matrix{<:Gaussian},
y::Matrix{<:Gaussian},
z::Matrix{<:Gaussian},
)
@warn "This plot does not visualize any uncertainties"
xmeans = mean.(x)
ymeans = mean.(y)
zmeans = mean.(z)
return xmeans, ymeans, zmeans
end

@recipe function f(
Expand All @@ -78,19 +68,13 @@ end
N_samples=10,
plot_derivatives=false,
)
marginals = ProbNumDiffEq.marginalize(process, plotrange)
d = ProbNumDiffEq.dim(process)
q = ProbNumDiffEq.num_derivatives(process)
means = [mean(m) for m in marginals] |> stack |> permutedims
stddevs = [std(m) for m in marginals] |> stack |> permutedims

perm = permutedims(reshape(collect(1:d*(q+1)), q + 1, d))[:]
reorder(X) = X[:, perm]
marginals = ProbNumDiffEq.marginalize(process, plotrange)

E0 = ProbNumDiffEq.projection(d, q)(0)
if !plot_derivatives
stddevs = stddevs * E0'
means = means * E0'
marginals = [E0 * m for m in marginals]
q = 0
end

Expand All @@ -106,15 +90,14 @@ end
end

@series begin
ribbon --> 3stddevs
label --> ""
fillalpha --> 0.1
layout --> if plot_derivatives
(d, q + 1)
else
d
end
plotrange, means
plotrange, marginals
end

if N_samples > 0
Expand Down
1 change: 1 addition & 0 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Base: copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate
using LinearAlgebra
import LinearAlgebra: mul!
import Statistics: mean, var, std, cov
using Printf

using Reexport
@reexport using DiffEqBase
Expand Down
4 changes: 2 additions & 2 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function OrdinaryDiffEq.alg_cache(
is_secondorder_ode = f isa DynamicalODEFunction

q = num_derivatives(alg.prior)
d = is_secondorder_ode ? length(u[1, :]) : length(u)
d = is_secondorder_ode ? length(u.x[1]) : length(u)
D = d * (q + 1)

uType = typeof(u)
Expand Down Expand Up @@ -175,7 +175,7 @@ function OrdinaryDiffEq.alg_cache(
measurement = Gaussian(v, S)

# Caches
du = is_secondorder_ode ? similar(u[2, :]) : similar(u)
du = is_secondorder_ode ? similar(u.x[2]) : similar(u)
ddu = factorized_similar(FAC, length(u), length(u))
_d = is_secondorder_ode ? 2d : d
pu_tmp = Gaussian(
Expand Down
2 changes: 1 addition & 1 deletion src/initialization/autodiffinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function initial_update!(integ, cache, init::AutodiffInitializationScheme)
for (o, df) in zip(0:q, f_derivatives)
if f isa DynamicalODEFunction
@assert df isa ArrayPartition
df = df[2, :]
df = df.x[2]
end

df = view(df, :)
Expand Down
2 changes: 1 addition & 1 deletion src/initialization/classicsolverinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
is_secondorder = integ.f isa DynamicalODEFunction
_u = is_secondorder ? view(u.x[2], :) : view(u, :)
init_condition_on!(x, Proj(0), _u, cache)
is_secondorder ? f.f1(du, u.x[1], u.x[2], p, t) : f(du, u, p, t)
is_secondorder ? u.x[1] : f(du, u, p, t)
integ.stats.nf += 1
init_condition_on!(x, Proj(1), view(du, :), cache)

Expand Down
4 changes: 2 additions & 2 deletions src/initialization/simpleinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ function initial_update!(integ, cache, init::SimpleInit)

if f isa DynamicalODEFunction
@assert u isa ArrayPartition
u = u[2, :]
u = u.x[2]
@assert du isa ArrayPartition
du = du[2, :]
du = du.x[2]
end

init_condition_on!(x, Proj(0), view(u, :), cache)
Expand Down
2 changes: 1 addition & 1 deletion src/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ The actual smoothing step happens by [`marginalize!`](@ref)ing backward kernels.
"""
function smooth_solution!(integ)
@unpack cache, sol = integ
sol.x_smooth = copy(sol.x_filt)
append!(sol.x_smooth, sol.x_filt)

@unpack x_smooth, t, backward_kernels = sol
@unpack C_DxD, C_3DxD = cache
Expand Down
2 changes: 1 addition & 1 deletion src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ function OrdinaryDiffEq.perform_step!(integ, cache::EKCache, repeat_step=false)
integ.u, x_filt.μ; cache, is_secondorder_ode=integ.f isa DynamicalODEFunction)

cache.log_likelihood = loglikelihood
integ.sol.log_likelihood += cache.log_likelihood
integ.sol.pnstats.log_likelihood += cache.log_likelihood

# Update the global diffusion MLE (if applicable)
if !isdynamic(cache.diffusionmodel)
Expand Down
6 changes: 2 additions & 4 deletions src/priors/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,8 @@ resulting Gaussian distributions.
See also: [`sample`](@ref).
"""
function marginalize(process::AbstractGaussMarkovProcess, times)
out = []
X = initial_distribution(process)
push!(out, X)
out = [X]
for i in 2:length(times)
dt = times[i] - times[i-1]
A, Q = ProbNumDiffEq.discretize(process, dt)
Expand All @@ -238,11 +237,10 @@ Samples from the Gauss-Markov process on the given time grid.
See also: [`marginalize`](@ref).
"""
function sample(process::AbstractGaussMarkovProcess, times, N=1)
out = []
X = initial_distribution(process)
X = Gaussian(mean(X), Matrix(cov(X)))
s = rand(X, N)
push!(out, s)
out = [s]
for i in 2:length(times)
dt = times[i] - times[i-1]
A, Q = Matrix.(discretize(process, dt))
Expand Down
2 changes: 1 addition & 1 deletion src/priors/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ initial_distribution(p::Matern{T}) where {T} = begin
D = d * (q + 1)
sde = to_sde(p)
μ0 = T <: LinearAlgebra.BlasFloat ? Array{T}(calloc, D) : zeros(T, D)
Σ0 = PSDMatrix(plyapc(sde.F, sde.L)')
Σ0 = PSDMatrix(IsometricKroneckerProduct(d, Matrix(plyapc(sde.F.B, sde.L.B)')))
return Gaussian(μ0, Σ0)
end

Expand Down
6 changes: 4 additions & 2 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ function projection(
D = d * (q + 1)
Proj(deriv) = begin
P = zeros(elType, d, D)
@simd ivdep for i in deriv*d+1:D+1:d*D
@inbounds P[i] = 1
if deriv <= q
@simd ivdep for i in deriv*d+1:D+1:d*D
@inbounds P[i] = 1
end
end
return P
end
Expand Down
Loading

0 comments on commit 6935a52

Please sign in to comment.