Skip to content

Commit

Permalink
Add support for PSDMatrices as observation noise
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 6, 2024
1 parent 3f00f06 commit 0093138
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
9 changes: 6 additions & 3 deletions src/callbacks/dataupdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ function DataUpdateCallback(
observation_noise_cov
end

_A = x.Σ.R * H'
# obs_cov = PSDMatrix(qr!([x.Σ.R * H'; sqrt(observation_noise_cov)]).R)
obs_cov = _A'_A + R
obs_cov = if observation_noise_cov isa PSDMatrix
PSDMatrix(qr!([x.Σ.R * H'; observation_noise_cov.R]).R)
else
_A = x.Σ.R * H'
_A'_A + R
end
obs = Gaussian(obs_mean, obs_cov)

_x = copy!(integ.cache.x_tmp, x)
Expand Down
20 changes: 15 additions & 5 deletions src/data_likelihoods/fenrir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,13 @@ function measure!(x, H, R, m_tmp)
z, S = m_tmp
_matmul!(z, H, x.μ)
fast_X_A_Xt!(S, x.Σ, H)
_S = Matrix(S) .+= R
return Gaussian(z, Symmetric(_S))
if R isa PSDMatrix
_S = PSDMatrix(qr([S.R; R.R]).R)
return Gaussian(z, _S)
else
_S = Matrix(S) .+= R
return Gaussian(z, Symmetric(_S))
end
end

function fenrir_update!(
Expand Down Expand Up @@ -179,7 +184,7 @@ function fenrir_update!(
end

S_chol = try
cholesky!(_S)
cholesky!(Symmetric(Matrix(_S)))
catch e
if !(e isa PosDefException)
rethrow(e)
Expand All @@ -205,8 +210,13 @@ function fenrir_update!(
fast_X_A_Xt!(x_out.Σ, P_p, M_cache)

if !iszero(R)
out_Sigma_R = [x_out.Σ.R; cholesky(R).U * K']
x_out.Σ.R .= triangularize!(out_Sigma_R; cachemat=M_cache)
if R isa PSDMatrix
out_Sigma_R = [x_out.Σ.R; R.R * K']
x_out.Σ.R .= triangularize!(out_Sigma_R; cachemat=M_cache)
else
out_Sigma_R = [x_out.Σ.R; cholesky(R).U * K']
x_out.Σ.R .= triangularize!(out_Sigma_R; cachemat=M_cache)
end
end

return x_out, loglikelihood
Expand Down
1 change: 1 addition & 0 deletions test/data_likelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ end
σ^2 * Eye(2),
^2 0; 0 2σ^2],
(A = randn(2, 2); A'A),
(PSDMatrix(randn(2, 2))),
)
compare_data_likelihoods(
EK1();
Expand Down

0 comments on commit 0093138

Please sign in to comment.