diff --git a/src/filtering/markov_kernel.jl b/src/filtering/markov_kernel.jl index 9223a84f7..ec873378a 100644 --- a/src/filtering/markov_kernel.jl +++ b/src/filtering/markov_kernel.jl @@ -69,39 +69,39 @@ Note that this function assumes certain shapes: `xout` is assumes to have the same shapes as `x`. """ function marginalize!(xout, x, K; C_DxD, C_3DxD) - marginalize_mean!(xout, x, K) - marginalize_cov!(xout, x, K; C_DxD, C_3DxD) + marginalize_mean!(xout.μ, x.μ, K) + marginalize_cov!(xout.Σ, x.Σ, K; C_DxD, C_3DxD) end -function marginalize_mean!(xout::Gaussian, x::Gaussian, K::AffineNormalKernel) - _matmul!(xout.μ, K.A, x.μ) +function marginalize_mean!(μout::AbstractVecOrMat, μ::AbstractVecOrMat, K::AffineNormalKernel) + _matmul!(μout, K.A, μ) if !ismissing(K.b) - xout.μ .+= K.b + μout .+= K.b end - return xout.μ + return μout end function marginalize_cov!( - x_out::SRGaussian, - x_curr::SRGaussian, + Σ_out::PSDMatrix, + Σ_curr::PSDMatrix, K::AffineNormalKernel{<:AbstractMatrix,<:Any,<:PSDMatrix}; C_DxD::AbstractMatrix, C_3DxD::AbstractMatrix, ) - _D = size(x_curr.Σ, 1) - A, b, C = K - R, M = C_3DxD, C_DxD + _D = size(Σ_curr, 1) + A, _, C = K + R = C_3DxD - _matmul!(view(R, 1:_D, 1:_D), x_curr.Σ.R, A') + _matmul!(view(R, 1:_D, 1:_D), Σ_curr.R, A') @.. R[_D+1:3_D, 1:_D] = C.R Q_R = triangularize!(R, cachemat=C_DxD) - copy!(x_out.Σ.R, Q_R) - return x_out.Σ + copy!(Σ_out.R, Q_R) + return Σ_out end function marginalize_cov!( - x_out::SRGaussian{T,<:IKP}, - x_curr::SRGaussian{T,<:IKP}, + Σ_out::PSDMatrix{T,<:IKP}, + Σ_curr::PSDMatrix{T,<:IKP}, K::AffineNormalKernel{ <:AbstractMatrix, <:Any, @@ -110,13 +110,13 @@ function marginalize_cov!( C_DxD::AbstractMatrix, C_3DxD::AbstractMatrix, ) where {T,S} - _x_out = Gaussian(x_out.μ, PSDMatrix(x_out.Σ.R.B)) - _x_curr = Gaussian(x_curr.μ, PSDMatrix(x_curr.Σ.R.B)) + _Σ_out = PSDMatrix(Σ_out.R.B) + _Σ_curr = PSDMatrix(Σ_curr.R.B) _K = AffineNormalKernel(K.A.B, K.b, PSDMatrix(K.C.R.B)) - _D = size(_x_out.Σ, 1) + _D = size(_Σ_out, 1) _C_DxD = view(C_DxD, 1:_D, 1:_D) _C_3DxD = view(C_3DxD, 1:3*_D, 1:_D) - return marginalize_cov!(_x_out, _x_curr, _K; C_DxD=_C_DxD, C_3DxD=_C_3DxD) + return marginalize_cov!(_Σ_out, _Σ_curr, _K; C_DxD=_C_DxD, C_3DxD=_C_3DxD) end """ diff --git a/test/core/filtering.jl b/test/core/filtering.jl index 4bfee2b17..e713b6318 100644 --- a/test/core/filtering.jl +++ b/test/core/filtering.jl @@ -348,8 +348,8 @@ end @test Matrix(K_backward.C) ≈ Λ C_3DxD = zeros(3d, d) - ProbNumDiffEq.marginalize_mean!(x_curr, x_next_smoothed, K_backward) - ProbNumDiffEq.marginalize_cov!(x_curr, x_next_smoothed, K_backward; C_DxD, C_3DxD) + ProbNumDiffEq.marginalize_mean!(x_curr.μ, x_next_smoothed.μ, K_backward) + ProbNumDiffEq.marginalize_cov!(x_curr.Σ, x_next_smoothed.Σ, K_backward; C_DxD, C_3DxD) @test m_smoothed ≈ x_curr.μ @test P_smoothed ≈ Matrix(x_curr.Σ)