From 786b1edb52fdaa75ad0eee9263b21e261c8bc3c5 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Sat, 17 Feb 2024 12:10:56 +0100 Subject: [PATCH] Get the data likelihoods to work with DiagonalEK1 and the EK0 --- src/ProbNumDiffEq.jl | 2 ++ src/blockdiagonals.jl | 19 +++++++++++++ src/callbacks/dataupdate.jl | 33 ++++++++++++++--------- src/covariance_structure.jl | 9 +++++++ src/data_likelihoods/fenrir.jl | 17 ++++++------ src/filtering/update.jl | 2 +- test/data_likelihoods.jl | 49 +++++++++++++++++++++++----------- 7 files changed, 93 insertions(+), 38 deletions(-) diff --git a/src/ProbNumDiffEq.jl b/src/ProbNumDiffEq.jl index 3f1e7fde8..a6771bdfc 100644 --- a/src/ProbNumDiffEq.jl +++ b/src/ProbNumDiffEq.jl @@ -46,6 +46,8 @@ vecvec2mat(x) = reduce(hcat, x)' cov2psdmatrix(cov::Number; d) = PSDMatrix(sqrt(cov) * Eye(d)) cov2psdmatrix(cov::UniformScaling; d) = PSDMatrix(sqrt(cov.λ) * Eye(d)) +cov2psdmatrix(cov::Diagonal{<:Number,<:FillArrays.Fill}; d) = + (@assert size(cov, 1) == size(cov, 2) == d; cov2psdmatrix(cov.diag.value; d)) cov2psdmatrix(cov::Diagonal; d) = (@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(sqrt.(cov))) cov2psdmatrix(cov::AbstractMatrix; d) = diff --git a/src/blockdiagonals.jl b/src/blockdiagonals.jl index 9714ebd37..240e9d10e 100644 --- a/src/blockdiagonals.jl +++ b/src/blockdiagonals.jl @@ -270,8 +270,27 @@ for _mul! in (:mul!, :_matmul!) end return C end + @eval $_mul!(C::BlockDiag, A::BlockDiag, B::Diagonal, alpha::Number, beta::Number) = begin + local i = 1 + map(zip(blocks(C), blocks(A))) do (Ci, Ai) + d = size(Ai, 2) + $_mul!(Ci, Ai, Diagonal(view(B.diag, i:(i+d-1))), alpha, beta) + i += d + end + return C + end + @eval $_mul!(C::BlockDiag, A::Diagonal, B::BlockDiag, alpha::Number, beta::Number) = begin + local i = 1 + map(zip(blocks(C), blocks(B))) do (Ci, Bi) + d = size(Bi, 1) + $_mul!(Ci, Diagonal(view(A.diag, i:(i+d-1))), Bi, alpha, beta) + i += d + end + return C + end end + Base.isequal(A::BlockDiag, B::BlockDiag) = length(A.blocks) == length(B.blocks) && all(map(isequal, A.blocks, B.blocks)) ==(A::BlockDiag, B::BlockDiag) = diff --git a/src/callbacks/dataupdate.jl b/src/callbacks/dataupdate.jl index c14223cab..67823fdcc 100644 --- a/src/callbacks/dataupdate.jl +++ b/src/callbacks/dataupdate.jl @@ -52,30 +52,26 @@ function DataUpdateCallback( o = length(val) @unpack x, E0, m_tmp, G1 = integ.cache - H = view(G1, 1:o, :) - if observation_matrix === I - @.. H = E0 - elseif observation_matrix isa UniformScaling - @.. H = observation_matrix.λ * E0 - else - matmul!(H, observation_matrix, E0) - end + M = observation_matrix + H = M * E0 obs_mean = _matmul!(view(m_tmp.μ, 1:o), H, x.μ) obs_mean .-= val R = cov2psdmatrix(observation_noise_cov; d=o) + R = to_factorized_matrix(integ.cache.covariance_factorization, R) # _A = x.Σ.R * H' # obs_cov = _A'_A + R - obs_cov = PSDMatrix(qr!([x.Σ.R * H'; R.R]).R) + obs_cov = PSDMatrix(make_obscov_sqrt(x.Σ.R, H, R.R)) + obs = Gaussian(obs_mean, obs_cov) @unpack x_tmp, K1, C_DxD, C_dxd, C_Dxd, C_d = integ.cache - K1 = view(K1, :, 1:o) - C_dxd = view(C_dxd, 1:o, 1:o) - C_Dxd = view(C_Dxd, :, 1:o) - C_d = view(C_d, 1:o) + K1 = K1 * M' + C_dxd = M * C_dxd * M' + C_Dxd = C_Dxd * M' + C_d = M * C_d _x = copy!(x_tmp, x) _, ll = update!(x, _x, obs, H, K1, C_Dxd, C_DxD, C_dxd, C_d; R=R) @@ -85,3 +81,14 @@ function DataUpdateCallback( end return PresetTimeCallback(data.t, affect!; save_positions, kwargs...) end + +make_obscov_sqrt(PR::AbstractMatrix, H::AbstractMatrix, RR::AbstractMatrix) = + qr!([PR * H'; RR]).R +make_obscov_sqrt( + PR::IsometricKroneckerProduct, + H::IsometricKroneckerProduct, + RR::IsometricKroneckerProduct, +) = + IsometricKroneckerProduct(PR.ldim, make_obscov_sqrt(PR.B, H.B, RR.B)) +make_obscov_sqrt(PR::BlockDiag, H::BlockDiag, RR::BlockDiag) = + BlockDiag([make_obscov_sqrt(blocks(PR)[i], blocks(H)[i], blocks(RR)[i]) for i in eachindex(blocks(PR))]) diff --git a/src/covariance_structure.jl b/src/covariance_structure.jl index 234a77688..41c003e71 100644 --- a/src/covariance_structure.jl +++ b/src/covariance_structure.jl @@ -47,6 +47,15 @@ to_factorized_matrix(::DenseCovariance, M::AbstractMatrix) = Matrix(M) to_factorized_matrix(::IsometricKroneckerCovariance, M::IsometricKroneckerProduct) = M to_factorized_matrix(C::BlockDiagonalCovariance, M::IsometricKroneckerProduct) = BlockDiag([copy(M.B) for _ in 1:C.d]) +to_factorized_matrix(C::BlockDiagonalCovariance, M::Diagonal) = + copy!(factorized_similar(C, size(M)...), M) +to_factorized_matrix( + C::IsometricKroneckerCovariance, M::Diagonal{<:Number, <:FillArrays.Fill}) = begin + out = factorized_similar(C, size(M)...) + @assert length(out.B) == 1 + out.B .= M.diag.value + out + end for FT in [:DenseCovariance, :IsometricKroneckerCovariance, :BlockDiagonalCovariance] @eval to_factorized_matrix(FAC::$FT, M::PSDMatrix) = diff --git a/src/data_likelihoods/fenrir.jl b/src/data_likelihoods/fenrir.jl index a805f9de9..05e1c3fd6 100644 --- a/src/data_likelihoods/fenrir.jl +++ b/src/data_likelihoods/fenrir.jl @@ -57,6 +57,7 @@ function fenrir_data_loglik( # Fit the ODE solution / PN posterior to the provided data; this is the actual Fenrir o = length(data.u[1]) R = cov2psdmatrix(observation_noise_cov; d=o) + R = to_factorized_matrix(integ.cache.covariance_factorization, R) LL, _, _ = fit_pnsolution_to_data!(sol, R, data; proj=observation_matrix) return LL @@ -78,12 +79,12 @@ function fit_pnsolution_to_data!( _cache = ( x_tmp=x_tmp, C_DxD=C_DxD, - C_Dxd=view(C_Dxd, :, 1:o), - C_dxd=view(C_dxd, 1:o, 1:o), - C_d=view(C_d, 1:o), - K1=view(K1, :, 1:o), - K2=view(C_Dxd, :, 1:o), - m_tmp=Gaussian(view(m_tmp.μ, 1:o), view(m_tmp.Σ, 1:o, 1:o)), + C_Dxd=C_Dxd * proj', + C_dxd=proj * C_dxd * proj', + C_d=proj * C_d, + K1=K1 * proj', + K2=C_Dxd * proj', + m_tmp=proj * m_tmp, ) x_posterior = copy(sol.x_filt) # the object to be filled @@ -136,9 +137,7 @@ function measure_and_update!(x, u, H, R::PSDMatrix, cache) z, S = cache.m_tmp _matmul!(z, H, x.μ) z .-= u - _matmul!(cache.C_Dxd, x.Σ.R, H') - _matmul!(S, cache.C_Dxd', cache.C_Dxd) - S .+= _matmul!(cache.C_dxd, R.R', R.R) + S = PSDMatrix(make_obscov_sqrt(x.Σ.R, H, R.R)) msmnt = Gaussian(z, S) return update!(x, copy!(cache.x_tmp, x), msmnt, H; R=R, cache) diff --git a/src/filtering/update.jl b/src/filtering/update.jl index 3b07b0a45..cf10c6df1 100644 --- a/src/filtering/update.jl +++ b/src/filtering/update.jl @@ -230,7 +230,7 @@ function update!( M_cache.blocks[i], C_dxd.blocks[i], view(C_d, i:i); - R, + R=isnothing(R) ? nothing : PSDMatrix(blocks(R.R)[i]) ) ll += _ll end diff --git a/test/data_likelihoods.jl b/test/data_likelihoods.jl index dacb2b6fd..2a25c09db 100644 --- a/test/data_likelihoods.jl +++ b/test/data_likelihoods.jl @@ -33,25 +33,27 @@ kwargs = ( ) @testset "Compare data likelihoods" begin @testset "$alg" for alg in ( + # EK0 + EK0(), + EK0(diffusionmodel=FixedDiffusion()), + EK0(prior=IOUP(3, -1)), + EK0(prior=Matern(3, 1.5)), + # EK1 EK1(), EK1(diffusionmodel=FixedDiffusion()), # EK1(diffusionmodel=FixedMVDiffusion(rand(2), false)), # not yet supported EK1(prior=IOUP(3, -1)), EK1(prior=Matern(3, 1.5)), EK1(prior=IOUP(3, update_rate_parameter=true)), + # DiagonalEK1 + DiagonalEK1(), + DiagonalEK1(diffusionmodel=FixedDiffusion()), + DiagonalEK1(diffusionmodel=FixedMVDiffusion(rand(2), false)), ) compare_data_likelihoods(alg; kwargs...) end end -@testset "EK0 is not (yet) supported" begin - for ll in (PNDE.dalton_data_loglik, PNDE.filtering_data_loglik) - @test_broken ll(prob, EK0(smooth=false); kwargs...) - end - @test_broken PNDE.fenrir_data_loglik( - prob, EK0(smooth=true); kwargs...) -end - @testset "Partial observations" begin H = [1 0;] data_part = (t=times, u=[H * d for d in obss]) @@ -63,6 +65,14 @@ end adaptive=false, dt=DT, dense=false, ) + @test_broken compare_data_likelihoods( + DiagonalEK1(); + observation_matrix=H, + observation_noise_cov=σ^2, + data=data_part, + adaptive=false, dt=DT, + dense=false, + ) end @testset "Observation noise types: $(typeof(Σ))" for Σ in ( @@ -70,15 +80,24 @@ end σ^2 * I, σ^2 * I(2), σ^2 * Eye(2), + Diagonal([σ^2 0; 0 2σ^2]), [σ^2 0; 0 2σ^2], (A = randn(2, 2); A'A), (PSDMatrix(randn(2, 2))), ) - compare_data_likelihoods( - EK1(); - observation_noise_cov=Σ, - data=data, - adaptive=false, dt=DT, - dense=false, - ) + @testset "$alg" for alg in (EK0(), DiagonalEK1(), EK1()) + if alg isa EK0 && !(Σ isa Number || Σ isa UniformScaling || Σ isa Diagonal{<:Number,<:FillArrays.Fill}) + continue + end + if alg isa DiagonalEK1 && !(Σ isa Number || Σ isa UniformScaling || Σ isa Diagonal) + continue + end + compare_data_likelihoods( + alg; + observation_noise_cov=Σ, + data=data, + adaptive=false, dt=DT, + dense=false, + ) + end end