Skip to content

Commit

Permalink
Make some filtering functions a bit easier to read
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Oct 20, 2023
1 parent 2c63247 commit 7085943
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 19 deletions.
20 changes: 9 additions & 11 deletions src/filtering/markov_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,31 +181,29 @@ function compute_backward_kernel!(
A, _, Q = K
G, b, Λ = Kout

D = length(x.μ)
_D = size(G, 1)
_a = D ÷ _D
D = output_dim = size(G, 1)

# G = Matrix(x.Σ) * A' / Matrix(xpred.Σ)
_matmul!(C_DxD, x.Σ.R, A')
_matmul!(G, x.Σ.R', C_DxD)
rdiv!(G, Cholesky(xpred.Σ.R, 'U', 0))

# b = μ - G * μ_pred
_matmul!(reshape_no_alloc(b, _D, _a), G, reshape_no_alloc(xpred.μ, _D, _a))
_matmul!(b, G, xpred.μ)
b .= x.μ .- b

# Λ.R[1:D, 1:D] = x.Σ.R * (I - G * A)'
_matmul!(C_DxD, A', G', -1.0, 0.0)
@inbounds @simd ivdep for i in 1:_D
@inbounds @simd ivdep for i in 1:D
C_DxD[i, i] += 1
end
_matmul!(view.R, 1:_D, 1:_D), x.Σ.R, C_DxD)
_matmul!(view.R, 1:D, 1:D), x.Σ.R, C_DxD)
# Λ.R[D+1:2D, 1:D] = (G * Q.R')'
if !isone(diffusion)
_matmul!(C_DxD, Q.R, sqrt.(diffusion))
_matmul!(view.R, _D+1:2_D, 1:_D), C_DxD, G')
_matmul!(view.R, D+1:2D, 1:D), C_DxD, G')
else
_matmul!(view.R, _D+1:2_D, 1:_D), Q.R, G')
_matmul!(view.R, D+1:2D, 1:D), Q.R, G')
end

return Kout
Expand All @@ -223,9 +221,9 @@ function compute_backward_kernel!(
KT1<:AffineNormalKernel{<:AbstractMatrix,<:AbstractVector,<:PSDMatrix},
KT2<:AffineNormalKernel{<:AbstractMatrix,<:Any,<:PSDMatrix},
}
D = length(x.μ)
d = K.A.ldim
Q = D ÷ d
D = full_state_dim = length(x.μ)
d = ode_dimension_dim = K.A.ldim
Q = n_derivatives_dim = D ÷ d
_Kout = AffineNormalKernel(Kout.A.B, reshape_no_alloc(Kout.b, Q, d), PSDMatrix(Kout.C.R.B))
_x_pred = Gaussian(reshape_no_alloc(xpred.μ, Q, d), PSDMatrix(xpred.Σ.R.B))
_x = Gaussian(reshape_no_alloc(x.μ, Q, d), PSDMatrix(x.Σ.R.B))
Expand Down
2 changes: 1 addition & 1 deletion src/filtering/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function predict_cov!(
return Σ_out
end
R, M = C_2DxD, C_DxD
D, D = size(Qh)
D = size(Qh, 1)

_matmul!(view(R, 1:D, 1:D), Σ_curr.R, Ah')
_matmul!(view(R, D+1:2D, 1:D), Qh.R, sqrt.(diffusion))
Expand Down
6 changes: 3 additions & 3 deletions src/filtering/smooth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ function smooth!(
cache,
diffusion::Union{Number,Diagonal}=1,
) where {T,S}
D = length(x_curr.μ)
d = Ah.ldim
Q = D ÷ d
D = full_state_dim = length(x_curr.μ)
d = ode_dimension_dim = Ah.ldim
Q = n_derivatives_dim = D ÷ d
_x_curr = Gaussian(reshape_no_alloc(x_curr.μ, Q, d), PSDMatrix(x_curr.Σ.R.B))
_x_next = Gaussian(reshape_no_alloc(x_next.μ, Q, d), PSDMatrix(x_next.Σ.R.B))
_Ah = Ah.B
Expand Down
8 changes: 4 additions & 4 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ function update!(
M_cache::AbstractMatrix,
C_dxd::AbstractMatrix,
) where {T}
D = length(x_out.μ)
d = H.ldim
Q = D ÷ d
D = full_state_dim = length(x_out.μ)
d = ode_dimension_dim = H.ldim
Q = n_derivatives_dim = D ÷ d
_x_out = Gaussian(reshape_no_alloc(x_out.μ, Q, d), PSDMatrix(x_out.Σ.R.B))
_x_pred = Gaussian(reshape_no_alloc(x_pred.μ, Q, d), PSDMatrix(x_pred.Σ.R.B))
_measurement = Gaussian(
reshape_no_alloc(measurement.μ, 1, d), PSDMatrix(measurement.Σ.R.B))
o = size(_measurement.Σ, 1)
o = measurement_dim = size(_measurement.Σ, 1)
_H = H.B
_D = length(x_out.μ) ÷ d
_K1_cache = view(K1_cache, 1:_D, 1:o)
Expand Down

0 comments on commit 7085943

Please sign in to comment.