From c430bb85c250c566e598704af132af30b18aa845 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Fri, 9 Feb 2024 15:50:52 +0000 Subject: [PATCH] Implement observation noise for the PN likelihood (#299) * Implement observation noise for the PN likelihood * Fix some issues that this created with the EK0 * Add check to make sure that calibration is off when R>0 * Use cov2psdmatrix also in the data likelihoods * Slightly more verbose error if R>0 and calibrate=true * More compact code in caches.jl * Add a warning whenever we use `triangularize!` just to be more aware of what acually happens * Make multivariate diffusions illegal with the EK1 * Revisit some diffusion-related stuff in diffusions.jl * Fix a bug with classicsolverinit * Make update compatible with PSDMatrix-values marginal obs covs again * Properly fix the RKinit bug * JuliaFormatter.jl * Fix more tests * JuliaFormatter.jl _again_ * Algorithm check was bad; this fixes it * JuliaFormatter.jl --- src/ProbNumDiffEq.jl | 8 ++ src/algorithms.jl | 103 +++++++++++++++--------- src/caches.jl | 6 +- src/callbacks/dataupdate.jl | 10 +-- src/data_likelihoods/fenrir.jl | 20 ++--- src/diffusions.jl | 39 +++++---- src/filtering/update.jl | 22 ++++- src/initialization/classicsolverinit.jl | 3 +- src/initialization/common.jl | 3 +- src/perform_step.jl | 12 ++- 10 files changed, 135 insertions(+), 91 deletions(-) diff --git a/src/ProbNumDiffEq.jl b/src/ProbNumDiffEq.jl index 8a32517c2..eac2a17cb 100644 --- a/src/ProbNumDiffEq.jl +++ b/src/ProbNumDiffEq.jl @@ -43,6 +43,14 @@ X_A_Xt(A, X) = X * A * X' stack(x) = copy(reduce(hcat, x)') vecvec2mat(x) = reduce(hcat, x)' +cov2psdmatrix(cov::Number; d) = PSDMatrix(sqrt(cov) * Eye(d)) +cov2psdmatrix(cov::UniformScaling; d) = PSDMatrix(sqrt(cov.λ) * Eye(d)) +cov2psdmatrix(cov::Diagonal; d) = + (@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(sqrt.(cov))) +cov2psdmatrix(cov::AbstractMatrix; d) = + (@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(Matrix(cholesky(cov).U))) +cov2psdmatrix(cov::PSDMatrix; d) = (@assert size(cov, 1) == size(cov, 2) == d; cov) + include("fast_linalg.jl") include("kronecker.jl") include("covariance_structure.jl") diff --git a/src/algorithms.jl b/src/algorithms.jl index d5c14766b..9982f0f5d 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -3,6 +3,26 @@ ######################################################################################## abstract type AbstractEK <: OrdinaryDiffEq.OrdinaryDiffEqAdaptiveAlgorithm end +function ekargcheck(alg; diffusionmodel, pn_observation_noise, kwargs...) + if (isstatic(diffusionmodel) && diffusionmodel.calibrate) && + (!isnothing(pn_observation_noise) && !iszero(pn_observation_noise)) + throw( + ArgumentError( + "Automatic calibration of global diffusion models is not possible when using observation noise. If you want to calibrate a global diffusion parameter, do so setting `calibrate=false` and optimizing `sol.pnstats.log_likelihood` manually.", + ), + ) + end + if ( + (diffusionmodel isa FixedMVDiffusion && diffusionmodel.calibrate) || + diffusionmodel isa DynamicMVDiffusion) && alg == EK1 + throw( + ArgumentError( + "The `EK1` algorithm does not support automatic calibration of multivariate diffusion models. Either use the `EK0` instead, or use a scalar diffusion model, or set `calibrate=false` and calibrate manually by optimizing `sol.pnstats.log_likelihood`.", + ), + ) + end +end + """ EK0(; order=3, smooth=true, @@ -38,19 +58,24 @@ julia> solve(prob, EK0()) # [References](@ref references) """ -struct EK0{PT,DT,IT} <: AbstractEK +struct EK0{PT,DT,IT,RT} <: AbstractEK prior::PT diffusionmodel::DT smooth::Bool initialization::IT + pn_observation_noise::RT + EK0(; order=3, + prior::PT=IWP(order), + diffusionmodel::DT=DynamicDiffusion(), + smooth=true, + initialization::IT=TaylorModeInit(num_derivatives(prior)), + pn_observation_noise::RT=nothing, + ) where {PT,DT,IT,RT} = begin + ekargcheck(EK0; diffusionmodel, pn_observation_noise) + new{PT,DT,IT,RT}( + prior, diffusionmodel, smooth, initialization, pn_observation_noise) + end end -EK0(; - order=3, - prior=IWP(order), - diffusionmodel=DynamicDiffusion(), - smooth=true, - initialization=TaylorModeInit(num_derivatives(prior)), -) = EK0(prior, diffusionmodel, smooth, initialization) _unwrap_val(::Val{B}) where {B} = B _unwrap_val(B) = B @@ -92,39 +117,45 @@ julia> solve(prob, EK1()) # [References](@ref references) """ -struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT} <: AbstractEK +struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK prior::PT diffusionmodel::DT smooth::Bool initialization::IT + pn_observation_noise::RT + EK1(; + order=3, + prior::PT=IWP(order), + diffusionmodel::DT=DynamicDiffusion(), + smooth=true, + initialization::IT=TaylorModeInit(num_derivatives(prior)), + chunk_size=Val{0}(), + autodiff=Val{true}(), + diff_type=Val{:forward}, + standardtag=Val{true}(), + concrete_jac=nothing, + pn_observation_noise::RT=nothing, + ) where {PT,DT,IT,RT} = begin + ekargcheck(EK1; diffusionmodel, pn_observation_noise) + new{ + _unwrap_val(chunk_size), + _unwrap_val(autodiff), + diff_type, + _unwrap_val(standardtag), + _unwrap_val(concrete_jac), + PT, + DT, + IT, + RT, + }( + prior, + diffusionmodel, + smooth, + initialization, + pn_observation_noise, + ) + end end -EK1(; - order=3, - prior::PT=IWP(order), - diffusionmodel::DT=DynamicDiffusion(), - smooth=true, - initialization::IT=TaylorModeInit(num_derivatives(prior)), - chunk_size=Val{0}(), - autodiff=Val{true}(), - diff_type=Val{:forward}, - standardtag=Val{true}(), - concrete_jac=nothing, -) where {PT,DT,IT} = - EK1{ - _unwrap_val(chunk_size), - _unwrap_val(autodiff), - diff_type, - _unwrap_val(standardtag), - _unwrap_val(concrete_jac), - PT, - DT, - IT, - }( - prior, - diffusionmodel, - smooth, - initialization, - ) """ ExpEK(; L, order=3, kwargs...) diff --git a/src/caches.jl b/src/caches.jl index 246e7f637..da8e191cd 100644 --- a/src/caches.jl +++ b/src/caches.jl @@ -168,10 +168,12 @@ function OrdinaryDiffEq.alg_cache( copy!(x0.Σ, apply_diffusion(x0.Σ, initdiff)) # Measurement model related things - R = nothing # factorized_similar(FAC, d, d) + R = + isnothing(alg.pn_observation_noise) ? nothing : + to_factorized_matrix(FAC, cov2psdmatrix(alg.pn_observation_noise; d)) H = factorized_similar(FAC, d, D) v = similar(Array{uElType}, d) - S = PSDMatrix(factorized_zeros(FAC, D, d)) + S = factorized_zeros(FAC, d, d) measurement = Gaussian(v, S) # Caches diff --git a/src/callbacks/dataupdate.jl b/src/callbacks/dataupdate.jl index 0c4bd5316..c14223cab 100644 --- a/src/callbacks/dataupdate.jl +++ b/src/callbacks/dataupdate.jl @@ -64,15 +64,7 @@ function DataUpdateCallback( obs_mean = _matmul!(view(m_tmp.μ, 1:o), H, x.μ) obs_mean .-= val - R = if observation_noise_cov isa PSDMatrix - observation_noise_cov - elseif observation_noise_cov isa Number - PSDMatrix(sqrt(observation_noise_cov) * Eye(o)) - elseif observation_noise_cov isa UniformScaling - PSDMatrix(sqrt(observation_noise_cov.λ) * Eye(o)) - else - PSDMatrix(cholesky(observation_noise_cov).U) - end + R = cov2psdmatrix(observation_noise_cov; d=o) # _A = x.Σ.R * H' # obs_cov = _A'_A + R diff --git a/src/data_likelihoods/fenrir.jl b/src/data_likelihoods/fenrir.jl index 4c507e60a..a805f9de9 100644 --- a/src/data_likelihoods/fenrir.jl +++ b/src/data_likelihoods/fenrir.jl @@ -56,15 +56,7 @@ function fenrir_data_loglik( # Fit the ODE solution / PN posterior to the provided data; this is the actual Fenrir o = length(data.u[1]) - R = if observation_noise_cov isa PSDMatrix - observation_noise_cov - elseif observation_noise_cov isa Number - PSDMatrix(sqrt(observation_noise_cov) * Eye(o)) - elseif observation_noise_cov isa UniformScaling - PSDMatrix(sqrt(observation_noise_cov.λ) * Eye(o)) - else - PSDMatrix(cholesky(observation_noise_cov).U) - end + R = cov2psdmatrix(observation_noise_cov; d=o) LL, _, _ = fit_pnsolution_to_data!(sol, R, data; proj=observation_matrix) return LL @@ -91,7 +83,7 @@ function fit_pnsolution_to_data!( C_d=view(C_d, 1:o), K1=view(K1, :, 1:o), K2=view(C_Dxd, :, 1:o), - m_tmp=Gaussian(view(m_tmp.μ, 1:o), PSDMatrix(view(m_tmp.Σ.R, :, 1:o))), + m_tmp=Gaussian(view(m_tmp.μ, 1:o), view(m_tmp.Σ, 1:o, 1:o)), ) x_posterior = copy(sol.x_filt) # the object to be filled @@ -144,10 +136,10 @@ function measure_and_update!(x, u, H, R::PSDMatrix, cache) z, S = cache.m_tmp _matmul!(z, H, x.μ) z .-= u - fast_X_A_Xt!(S, x.Σ, H) - # _S = PSDMatrix(S.R'S.R + R.R'R.R) - _S = PSDMatrix(triangularize!([S.R; R.R], cachemat=cache.C_DxD)) - msmnt = Gaussian(z, _S) + _matmul!(cache.C_Dxd, x.Σ.R, H') + _matmul!(S, cache.C_Dxd', cache.C_Dxd) + S .+= _matmul!(cache.C_dxd, R.R', R.R) + msmnt = Gaussian(z, S) return update!(x, copy!(cache.x_tmp, x), msmnt, H; R=R, cache) end diff --git a/src/diffusions.jl b/src/diffusions.jl index fbaf0e4d3..3570a22f7 100644 --- a/src/diffusions.jl +++ b/src/diffusions.jl @@ -69,13 +69,12 @@ function estimate_global_diffusion(::FixedDiffusion, integ) v, S = measurement.μ, measurement.Σ e = m_tmp.μ - _S = _matmul!(Smat, S.R', S.R) e .= v - diffusion_t = if _S isa IsometricKroneckerProduct - @assert length(_S.B) == 1 - dot(v, e) / d / _S.B[1] + diffusion_t = if S isa IsometricKroneckerProduct + @assert length(S.B) == 1 + dot(v, e) / d / S.B[1] else - S_chol = cholesky!(_S) + S_chol = cholesky!(S) ldiv!(S_chol, e) dot(v, e) / d end @@ -123,13 +122,12 @@ function estimate_global_diffusion(::FixedMVDiffusion, integ) @unpack d, q, measurement, local_diffusion = integ.cache v, S = measurement.μ, measurement.Σ - # S_11 = diag(S)[1] - c1 = view(S.R, :, 1) - S_11 = dot(c1, c1) + # @assert diag(S) |> unique |> length == 1 + S_11 = S[1, 1] Σ_ii = v .^ 2 ./ S_11 Σ = Diagonal(Σ_ii) - Σ_out = kron(Σ, I(q + 1)) + Σ_out = kron(Σ, I(q + 1)) # -> Different for each dimension; same for each derivative if integ.success_iter == 0 # @assert length(diffusions) == 0 @@ -159,17 +157,17 @@ For more background information * [bosch20capos](@cite) Bosch et al, "Calibrated Adaptive Probabilistic ODE Solvers", AISTATS (2021) """ function local_scalar_diffusion(cache) - @unpack d, R, H, Qh, measurement, m_tmp, Smat = cache + @unpack d, R, H, Qh, measurement, m_tmp, Smat, C_Dxd = cache z = measurement.μ e, HQH = m_tmp.μ, m_tmp.Σ - fast_X_A_Xt!(HQH, Qh, H) - HQHmat = _matmul!(Smat, HQH.R', HQH.R) + _matmul!(C_Dxd, Qh.R, H') + _matmul!(HQH, C_Dxd', C_Dxd) e .= z - σ² = if HQHmat isa IsometricKroneckerProduct - @assert length(HQHmat.B) == 1 - dot(z, e) / d / HQHmat.B[1] + σ² = if HQH isa IsometricKroneckerProduct + @assert length(HQH.B) == 1 + dot(z, e) / d / HQH.B[1] else - C = cholesky!(HQHmat) + C = cholesky!(HQH) ldiv!(C, e) dot(z, e) / d end @@ -195,16 +193,17 @@ function local_diagonal_diffusion(cache) @unpack d, q, H, Qh, measurement, m_tmp, tmp = cache @unpack local_diffusion = cache z = measurement.μ - HQH = fast_X_A_Xt!(m_tmp.Σ, Qh, H) - # Q0_11 = diag(HQH)[1] - c1 = view(HQH.R, :, 1) + # HQH = H * unfactorize(Qh) * H' + # @assert HQH |> diag |> unique |> length == 1 + # c1 = view(_matmul!(cache.C_Dxd, Qh.R, H'), :, 1) + c1 = mul!(view(cache.C_Dxd, :, 1:1), Qh.R, view(H, 1:1, :)') Q0_11 = dot(c1, c1) Σ_ii = @. m_tmp.μ = z^2 / Q0_11 - # Σ_ii .= max.(Σ_ii, eps(eltype(Σ_ii))) Σ = Diagonal(Σ_ii) # local_diffusion = kron(Σ, I(q+1)) + # -> Different for each dimension; same for each derivative for i in 1:d for j in (i-1)*(q+1)+1:i*(q+1) local_diffusion[j, j] = Σ[i, i] diff --git a/src/filtering/update.jl b/src/filtering/update.jl index b5be08cde..3d9125fcd 100644 --- a/src/filtering/update.jl +++ b/src/filtering/update.jl @@ -122,7 +122,16 @@ function update!( fast_X_A_Xt!(x_out.Σ, P_p, M_cache) if !isnothing(R) - x_out.Σ.R .= triangularize!([x_out.Σ.R; R.R * K']; cachemat=M_cache) + # M = Matrix(x_out.Σ) + K * Matrix(R) * K' + _matmul!(M_cache, x_out.Σ.R', x_out.Σ.R) + _matmul!(K1_cache, K, R.R') + _matmul!(M_cache, K1_cache, K1_cache', 1, 1) + chol = cholesky!(Symmetric(M_cache), check=false) + if issuccess(chol) + copy!(x_out.Σ.R, chol.U) + else + x_out.Σ.R .= triangularize!([x_out.Σ.R; K1_cache']; cachemat=M_cache) + end end return x_out, loglikelihood @@ -141,7 +150,10 @@ end function update!( x_out::SRGaussian{T,<:IsometricKroneckerProduct}, x_pred::SRGaussian{T,<:IsometricKroneckerProduct}, - measurement::SRGaussian{T,<:IsometricKroneckerProduct}, + measurement::Gaussian{ + <:AbstractVector, + <:Union{<:PSDMatrix{T,<:IsometricKroneckerProduct},<:IsometricKroneckerProduct}, + }, H::IsometricKroneckerProduct, K1_cache::IsometricKroneckerProduct, K2_cache::IsometricKroneckerProduct, @@ -156,7 +168,9 @@ function update!( _x_out = Gaussian(reshape_no_alloc(x_out.μ, Q, d), PSDMatrix(x_out.Σ.R.B)) _x_pred = Gaussian(reshape_no_alloc(x_pred.μ, Q, d), PSDMatrix(x_pred.Σ.R.B)) _measurement = Gaussian( - reshape_no_alloc(measurement.μ, 1, d), PSDMatrix(measurement.Σ.R.B)) + reshape_no_alloc(measurement.μ, 1, d), + measurement.Σ isa PSDMatrix ? PSDMatrix(measurement.Σ.R.B) : measurement.Σ.B, + ) _H = H.B _K1_cache = K1_cache.B _K2_cache = K2_cache.B @@ -180,7 +194,7 @@ function update!( end # Short-hand with cache -function update!(x_out, x, measurement, H; R=nothing, cache) +function update!(x_out, x, measurement, H; cache, R=nothing) @unpack K1, m_tmp, C_DxD, C_dxd, C_Dxd, C_d = cache K2 = C_Dxd return update!(x_out, x, measurement, H, K1, K2, C_DxD, C_dxd, C_d; R) diff --git a/src/initialization/classicsolverinit.jl b/src/initialization/classicsolverinit.jl index 2d09fb572..e7bd1547f 100644 --- a/src/initialization/classicsolverinit.jl +++ b/src/initialization/classicsolverinit.jl @@ -119,7 +119,8 @@ function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt) H = cache.E0 * PI measurement.μ .= H * x_pred.μ .- u - fast_X_A_Xt!(measurement.Σ, x_pred.Σ, H) + _matmul!(C_Dxd, x_pred.Σ.R, H') + _matmul!(measurement.Σ, C_Dxd', C_Dxd) update!(x_filt, x_pred, measurement, H, K1, C_Dxd, C_DxD, C_dxd, C_d) push!(filts, copy(x_filt)) diff --git a/src/initialization/common.jl b/src/initialization/common.jl index 58bf27741..0d6909053 100644 --- a/src/initialization/common.jl +++ b/src/initialization/common.jl @@ -128,7 +128,8 @@ function init_condition_on!( m_tmp.μ .-= data # measurement cov - fast_X_A_Xt!(m_tmp.Σ, x.Σ, H) + _matmul!(C_Dxd, x.Σ.R, H') + _matmul!(m_tmp.Σ, C_Dxd', C_Dxd) copy!(x_tmp, x) update!(x, x_tmp, m_tmp, H, K1, C_Dxd, C_DxD, C_dxd, C_d) end diff --git a/src/perform_step.jl b/src/perform_step.jl index 005c20da9..29574b68a 100644 --- a/src/perform_step.jl +++ b/src/perform_step.jl @@ -114,7 +114,8 @@ function OrdinaryDiffEq.perform_step!(integ, cache::EKCache, repeat_step=false) compute_measurement_covariance!(cache) # Update state and save the ODE solution value - x_filt, loglikelihood = update!(x_filt, x_pred, cache.measurement, cache.H; cache) + x_filt, loglikelihood = update!( + x_filt, x_pred, cache.measurement, cache.H; cache, R=cache.R) write_into_solution!( integ.u, x_filt.μ; cache, is_secondorder_ode=integ.f isa DynamicalODEFunction) @@ -159,8 +160,11 @@ function evaluate_ode!(integ, x_pred, t) end compute_measurement_covariance!(cache) = begin - @assert isnothing(cache.R) - fast_X_A_Xt!(cache.measurement.Σ, cache.x_pred.Σ, cache.H) + _matmul!(cache.C_Dxd, cache.x_pred.Σ.R, cache.H') + _matmul!(cache.measurement.Σ, cache.C_Dxd', cache.C_Dxd) + if !isnothing(cache.R) + cache.measurement.Σ .+= _matmul!(cache.C_dxd, cache.R.R', cache.R.R) + end end """ @@ -216,7 +220,7 @@ To save allocations, the function modifies the given `cache` and writes into function estimate_errors!(cache::AbstractODEFilterCache) @unpack local_diffusion, Qh, H, d = cache - R = cache.measurement.Σ.R + R = cache.C_Dxd if local_diffusion isa Diagonal _QR = cache.C_DxD .= Qh.R .* sqrt.(local_diffusion.diag)'