diff --git a/src/filtering/markov_kernel.jl b/src/filtering/markov_kernel.jl index e980108dd..1884430a2 100644 --- a/src/filtering/markov_kernel.jl +++ b/src/filtering/markov_kernel.jl @@ -120,6 +120,28 @@ function marginalize_cov!( return marginalize_cov!(_Σ_out, _Σ_curr, _K; C_DxD=_C_DxD, C_3DxD=_C_3DxD) end +function marginalize_cov!( + Σ_out::PSDMatrix{T,<:BlockDiagonal}, + Σ_curr::PSDMatrix{T,<:BlockDiagonal}, + K::AffineNormalKernel{ + <:AbstractMatrix, + <:Any, + <:PSDMatrix{S,<:BlockDiagonal}, + }; + C_DxD::AbstractMatrix, + C_3DxD::AbstractMatrix, +) where {T,S} + for i in eachindex(blocks(Σ_out.R)) + _Σ_out = PSDMatrix(Σ_out.R.blocks[i]) + _Σ_curr = PSDMatrix(Σ_curr.R.blocks[i]) + _K = AffineNormalKernel(K.A.blocks[i], K.b, PSDMatrix(K.C.R.blocks[i])) + _C_DxD = C_DxD.blocks[i] + _C_3DxD = C_3DxD.blocks[i] + marginalize_cov!(_Σ_out, _Σ_curr, _K; C_DxD=_C_DxD, C_3DxD=_C_3DxD) + end + return Σ_out +end + """ compute_backward_kernel!(Kout, xpred, x, K; C_DxD[, diffusion=1]) @@ -243,3 +265,52 @@ function compute_backward_kernel!( return compute_backward_kernel!( _Kout, _x_pred, _x, _K; C_DxD=_C_DxD, diffusion=diffusion) end + +function compute_backward_kernel!( + Kout::KT1, + xpred::SRGaussian{T,<:BlockDiagonal}, + x::SRGaussian{T,<:BlockDiagonal}, + K::KT2; + C_DxD::AbstractMatrix, + diffusion=1, +) where { + T, + KT1<:AffineNormalKernel{ + <:BlockDiagonal, + <:AbstractVector, + <:PSDMatrix{T,<:BlockDiagonal}, + }, + KT2<:AffineNormalKernel{ + <:BlockDiagonal, + <:Any, + <:PSDMatrix{T,<:BlockDiagonal}, + }, +} + d = length(blocks(xpred.Σ.R)) + q = size(blocks(xpred.Σ.R)[1], 1) - 1 + for i in eachindex(blocks(xpred.Σ.R)) + _Kout = AffineNormalKernel( + Kout.A.blocks[i], + view(Kout.b, (i-1)*(q+1)+1:i*(q+1)), + PSDMatrix(Kout.C.R.blocks[i]) + ) + _xpred = Gaussian( + view(xpred.μ, (i-1)*(q+1)+1:i*(q+1)), + PSDMatrix(xpred.Σ.R.blocks[i]) + ) + _x = Gaussian( + view(x.μ, (i-1)*(q+1)+1:i*(q+1)), + PSDMatrix(x.Σ.R.blocks[i]) + ) + _K = AffineNormalKernel( + K.A.blocks[i], + ismissing(K.b) ? missing : view(K.b, (i-1)*(q+1)+1:i*(q+1)), + PSDMatrix(K.C.R.blocks[i]) + ) + _C_DxD = C_DxD.blocks[i] + compute_backward_kernel!( + _Kout, _xpred, _x, _K, C_DxD=_C_DxD, diffusion=diffusion + ) + end + return Kout +end diff --git a/src/filtering/predict.jl b/src/filtering/predict.jl index 630f91dca..06bb67b57 100644 --- a/src/filtering/predict.jl +++ b/src/filtering/predict.jl @@ -135,4 +135,5 @@ function predict_cov!( diffusion, ) end + return Σ_out end