Skip to content

Commit

Permalink
Make ClassicSovlerInit use backward transitions and completely remove…
Browse files Browse the repository at this point in the history
… `smooth!` from the package (#262)
  • Loading branch information
nathanaelbosch authored Nov 1, 2023
1 parent 72e382c commit 9e8c8e0
Show file tree
Hide file tree
Showing 11 changed files with 1,861 additions and 1,961 deletions.
266 changes: 133 additions & 133 deletions docs/src/benchmarks/figures/lotkavolterra_2_1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
608 changes: 304 additions & 304 deletions docs/src/benchmarks/figures/lotkavolterra_3_1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
612 changes: 306 additions & 306 deletions docs/src/benchmarks/figures/lotkavolterra_4_1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
724 changes: 362 additions & 362 deletions docs/src/benchmarks/figures/lotkavolterra_5_1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
688 changes: 344 additions & 344 deletions docs/src/benchmarks/figures/lotkavolterra_6_1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
752 changes: 376 additions & 376 deletions docs/src/benchmarks/figures/lotkavolterra_7_1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 16 additions & 15 deletions docs/src/benchmarks/lotkavolterra.md
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ Platform Info:
Environment:
JULIA_NUM_THREADS = auto
JULIA_STACKTRACE_MINIMAL = true
JULIA_IMAGE_THREADS = 1
```


Expand All @@ -301,10 +302,10 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Project.toml`
[65888b18] ParameterizedFunctions v5.16.0
[91a5bcdd] Plots v1.39.0
[bf3e78b0] ProbNumDiffEq v0.12.1 `~/.julia/dev/ProbNumDiffEq`
[0bca4576] SciMLBase v2.4.3
[0bca4576] SciMLBase v2.5.0
[505e40e9] SciPyDiffEq v0.2.1
[90137ffa] StaticArrays v1.6.5
[c3572dad] Sundials v4.20.0
[c3572dad] Sundials v4.20.1
[44d3d7a6] Weave v0.10.12
[0518478a] deSolveDiffEq v0.1.1
```
Expand All @@ -324,16 +325,16 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
⌅ [c3fe647b] AbstractAlgebra v0.32.5
[621f4979] AbstractFFTs v1.5.0
[1520ce14] AbstractTrees v0.4.4
[79e6a3ab] Adapt v3.7.0
[79e6a3ab] Adapt v3.7.1
[ec485272] ArnoldiMethod v0.2.0
[c9d4266f] ArrayAllocators v0.3.0
[4fba245c] ArrayInterface v7.4.11
[4fba245c] ArrayInterface v7.5.0
[30b0a656] ArrayInterfaceCore v0.1.29
[6e4b80f9] BenchmarkTools v1.3.2
[e2ed5e7c] Bijections v0.1.6
[d1d4a3ce] BitFlags v0.1.7
[62783981] BitTwiddlingConvenienceFunctions v0.1.5
[fa961155] CEnum v0.4.2
[fa961155] CEnum v0.5.0
[2a0fbf3d] CPUSummary v0.2.4
[00ebfdb7] CSTParser v3.3.6
[49dc2e85] Calculus v0.5.1
Expand Down Expand Up @@ -365,7 +366,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[864edb3b] DataStructures v0.18.15
[e2d170a0] DataValueInterfaces v1.0.0
[8bb1440f] DelimitedFiles v1.9.1
[2b5f629d] DiffEqBase v6.134.0
[2b5f629d] DiffEqBase v6.135.0
[459566f4] DiffEqCallbacks v2.33.1
[f3b72e0c] DiffEqDevTools v2.39.0
[77a26b50] DiffEqNoiseProcess v5.19.0
Expand All @@ -379,7 +380,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[7c1d4256] DynamicPolynomials v0.5.3
[b305315f] Elliptic v1.0.1
[4e289a0a] EnumX v1.0.4
[f151be2c] EnzymeCore v0.6.2
[f151be2c] EnzymeCore v0.6.3
[6912e4f1] Espresso v0.6.1
[460bff9d] ExceptionUnwrapping v0.1.9
[d4d017d3] ExponentialUtilities v1.25.0
Expand All @@ -405,7 +406,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[c27321d9] Glob v1.3.1
[86223c79] Graphs v1.9.0
[42e2da0e] Grisu v1.0.2
[0b43b601] Groebner v0.4.4
[0b43b601] Groebner v0.4.4
[d5909c97] GroupsCore v0.4.0
[cd3eb016] HTTP v1.10.0
[eafb193a] Highlights v0.5.2
Expand All @@ -424,7 +425,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[1019f520] JLFzf v0.1.6
[692b3bcd] JLLWrappers v1.5.0
[682c06a0] JSON v0.21.4
[98e50ef6] JuliaFormatter v1.0.40
[98e50ef6] JuliaFormatter v1.0.41
[ccbc3e58] JumpProcesses v9.8.0
[ef3ab10e] KLU v0.4.1
[2c470bb0] Kronecker v0.5.4
Expand All @@ -438,7 +439,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[50d2b5c4] Lazy v0.15.1
[1d6d02ad] LeftChildRightSiblingTrees v0.2.0
[d3d80556] LineSearches v7.2.0
[7ed4a6bd] LinearSolve v2.12.1
[7ed4a6bd] LinearSolve v2.14.0
[2ab3a3ac] LogExpFunctions v0.3.26
[e6f89c97] LoggingExtras v1.0.3
[bdcacae8] LoopVectorization v0.12.165
Expand All @@ -459,7 +460,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[2774e3e8] NLsolve v4.5.1
[77ba4419] NaNMath v1.0.2
⌅ [356022a1] NamedDims v0.2.50
[8913a72c] NonlinearSolve v2.4.0
[8913a72c] NonlinearSolve v2.5.0
[54ca160b] ODEInterface v0.5.0
[09606e27] ODEInterfaceDiffEq v3.13.3
[6fd5a793] Octavian v0.3.27
Expand Down Expand Up @@ -511,8 +512,8 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[7e49a35a] RuntimeGeneratedFunctions v0.5.12
[fdea26ae] SIMD v3.4.5
[94e857df] SIMDTypes v0.1.0
[476501e8] SLEEFPirates v0.6.39
[0bca4576] SciMLBase v2.4.3
[476501e8] SLEEFPirates v0.6.40
[0bca4576] SciMLBase v2.5.0
[e9a6253c] SciMLNLSolve v0.1.9
[c0aeaf25] SciMLOperators v0.3.6
[505e40e9] SciPyDiffEq v0.2.1
Expand All @@ -528,7 +529,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[66db9d55] SnoopPrecompile v1.0.3
[b85f4697] SoftGlobalScope v1.1.0
[a2af1166] SortingAlgorithms v1.2.0
[47a9eef4] SparseDiffTools v2.8.0
[47a9eef4] SparseDiffTools v2.9.1
[e56a9233] Sparspak v0.3.9
[276daf66] SpecialFunctions v2.3.1
[928aab9d] SpecialMatrices v3.0.0
Expand All @@ -544,7 +545,7 @@ Status `~/.julia/dev/ProbNumDiffEq/benchmarks/Manifest.toml`
[69024149] StringEncodings v0.3.7
[892a3eda] StringManipulation v0.3.4
[09ab397b] StructArrays v0.6.16
[c3572dad] Sundials v4.20.0
[c3572dad] Sundials v4.20.1
[2efcf032] SymbolicIndexingInterface v0.2.2
[d1185830] SymbolicUtils v1.4.0
[0c5d862f] Symbolics v5.10.0
Expand Down
8 changes: 7 additions & 1 deletion docs/src/filtering.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@ ProbNumDiffEq.update!

```@docs
ProbNumDiffEq.smooth
ProbNumDiffEq.smooth!
```

## Markov Kernels
```@docs
ProbNumDiffEq.AffineNormalKernel
ProbNumDiffEq.marginalize!
ProbNumDiffEq.compute_backward_kernel!
```
86 changes: 0 additions & 86 deletions src/filtering/smooth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ G &= Σ_n^S A^T (Σ_{n+1}^P)^{-1}, \\\\
and return a smoothed state `\\mathcal{N}(μ_n^S, Σ_n^S)`.
When called with `ProbNumDiffEq.SquarerootMatrix` type arguments it performs the update in
Joseph / square-root form.
For better performance, we recommend to use the non-allocating [`smooth!`](@ref).
"""
function smooth(
x_curr::Gaussian,
Expand Down Expand Up @@ -65,87 +63,3 @@ function smooth(
x_curr_smoothed = Gaussian(smoothed_mean, smoothed_cov)
return x_curr_smoothed, G
end

"""
smooth!(x_curr, x_next, Ah, Qh, cache, diffusion=1)
In-place and square-root implementation of [`smooth`](@ref) which overwrites `x_curr`.
Implemented in Joseph form to preserve square-root structure.
It requires access to the solvers `cache`
to prevent allocations.
See also: [`smooth`](@ref).
"""
function smooth!(
x_curr::SRGaussian,
x_next::SRGaussian,
Ah::AbstractMatrix,
Qh::PSDMatrix,
cache,
diffusion::Union{Number,Diagonal}=1,
)
# x_curr is the state at time t_n (filter estimate) that we want to smooth
# x_next is the state at time t_{n+1}, already smoothed, which we use for smoothing
@unpack x_pred = cache
@unpack G1, C_DxD, C_2DxD, C_3DxD = cache
D = length(x_curr.μ)
_D = size(C_DxD, 1)

# Prediction: t -> t+1
predict_mean!(x_pred.μ, x_curr.μ, Ah)
predict_cov!(x_pred.Σ, x_curr.Σ, Ah, Qh, C_DxD, C_2DxD, diffusion)

# Smoothing
# G = x_curr.Σ * Ah' * P_p_inv
P_p_chol = Cholesky(x_pred.Σ.R, :U, 0)
G = rdiv!(_matmul!(G1, x_curr.Σ.R', _matmul!(C_DxD, x_curr.Σ.R, Ah')), P_p_chol)

# x_curr.μ .+= G * (x_next.μ .- x_pred.μ) # less allocations:
x_pred.μ .-= x_next.μ
_matmul!(x_curr.μ, G, x_pred.μ, -1, 1)

# Joseph-Form:
R = C_3DxD

G2 = _matmul!(C_DxD, G, Ah)
copy!(view(R, 1:_D, 1:_D), x_curr.Σ.R)
_matmul!(view(R, 1:_D, 1:_D), x_curr.Σ.R, G2', -1.0, 1.0)

_matmul!(view(R, _D+1:2_D, 1:_D), Qh.R, _matmul!(G2, G, sqrt.(diffusion))')
_matmul!(view(R, 2_D+1:3_D, 1:_D), x_next.Σ.R, G')

Q_R = triangularize!(R, cachemat=C_DxD)
copy!(x_curr.Σ.R, Q_R)

return nothing
end

function smooth!(
x_curr::SRGaussian{T,<:IsometricKroneckerProduct},
x_next::SRGaussian{T,<:IsometricKroneckerProduct},
Ah::IsometricKroneckerProduct,
Qh::PSDMatrix{S,<:IsometricKroneckerProduct},
cache,
diffusion::Union{Number,Diagonal}=1,
) where {T,S}
D = length(x_curr.μ) # full_state_dim
d = Ah.ldim # ode_dimension_dim
Q = D ÷ d # n_derivatives_dim
_x_curr = Gaussian(reshape_no_alloc(x_curr.μ, Q, d), PSDMatrix(x_curr.Σ.R.B))
_x_next = Gaussian(reshape_no_alloc(x_next.μ, Q, d), PSDMatrix(x_next.Σ.R.B))
_Ah = Ah.B
_Qh = PSDMatrix(Qh.R.B)
_cache = (
G1=cache.G1.B,
C_DxD=cache.C_DxD.B,
C_2DxD=cache.C_2DxD.B,
C_3DxD=cache.C_3DxD.B,
x_pred=Gaussian(
reshape_no_alloc(cache.x_pred.μ, Q, d),
PSDMatrix(cache.x_pred.Σ.R.B),
),
)

return smooth!(_x_curr, _x_next, _Ah, _Qh, _cache, diffusion)
end
19 changes: 13 additions & 6 deletions src/initialization/classicsolverinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ end

function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt)
@unpack A, Q = cache
# @unpack Ah, Qh = cache
@unpack x, x_pred, x_filt, measurement = cache
@unpack K1, C_Dxd, C_DxD, C_dxd = cache
@unpack K1, C_Dxd, C_DxD, C_dxd, C_3DxD = cache
@unpack backward_kernel = cache

# Predict forward:
make_preconditioners!(cache, dt)
Expand All @@ -100,14 +102,21 @@ function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt)

preds = []
filts = [copy(x)]
backward_kernels = []

# Filter through the data forwards
for (i, (t, u)) in enumerate(zip(ts, us))
(u isa RecursiveArrayTools.ArrayPartition) && (u = u.x[2]) # for 2ndOrderODEs
u = view(u, :) # just in case the problem is matrix-valued

predict!(x_pred, x, A, Q, cache.C_DxD, cache.C_2DxD, cache.default_diffusion)
push!(preds, copy(x_pred))

K = AffineNormalKernel(A, Q)
compute_backward_kernel!(
backward_kernel, x_pred, x, K; C_DxD, diffusion=cache.default_diffusion)
push!(backward_kernels, copy(backward_kernel))

H = cache.E0 * PI
measurement.μ .= H * x_pred.μ .- u
fast_X_A_Xt!(measurement.Σ, x_pred.Σ, H)
Expand All @@ -119,11 +128,9 @@ function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt)
end

# Smooth backwards
for i in length(filts):-1:2
xf = filts[i-1]
xs = filts[i]
xp = preds[i-1] # Since `preds` is one shorter
smooth!(xf, xs, A, Q, cache, 1)
x_smooth = filts
for i in length(x_smooth)-1:-1:1
marginalize!(x_smooth[i], x_smooth[i+1], backward_kernels[i]; C_DxD, C_3DxD)
end

_gaussian_mul!(cache.x, PI, filts[1])
Expand Down
28 changes: 0 additions & 28 deletions test/core/filtering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,34 +319,6 @@ end
@test m_smoothed x_out.μ
@test P_smoothed Matrix(x_out.Σ)
end
@testset "smooth!" begin
_d = d
x_curr_psd = Gaussian(m, PSDMatrix(P_R)) |> copy
x_next_psd = Gaussian(m_s, PSDMatrix(P_s_R)) |> copy
cache = if !KRONECKER
(x_pred=copy(x_curr_psd),
G1=zeros(_d, _d),
C_DxD=zeros(_d, _d),
C_2DxD=zeros(2_d, _d),
C_3DxD=zeros(3_d, _d))
else
(x_pred=copy(x_curr_psd),
G1=IsometricKroneckerProduct(K, zeros(_d, _d)),
C_DxD=IsometricKroneckerProduct(K, zeros(_d, _d)),
C_2DxD=IsometricKroneckerProduct(K, zeros(2_d, _d)),
C_3DxD=IsometricKroneckerProduct(K, zeros(3_d, _d)))
end
ProbNumDiffEq.smooth!(
x_curr_psd,
x_next_psd,
A,
Q_SR,
cache,
)
@test m_smoothed x_curr_psd.μ
@test P_smoothed Matrix(x_curr_psd.Σ)
end

@testset "smooth via backward kernels" begin
K_forward = ProbNumDiffEq.AffineNormalKernel(copy(A), copy(Q_SR))
K_backward = ProbNumDiffEq.AffineNormalKernel(
Expand Down

0 comments on commit 9e8c8e0

Please sign in to comment.