Skip to content

Commit

Permalink
Make update! compute the log-likelihood too
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Jan 4, 2024
1 parent 09fb1fd commit 0cd7654
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 14 deletions.
8 changes: 5 additions & 3 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mutable struct EKCache{
RType,CFacType,ProjType,SolProjType,PType,PIType,EType,uType,duType,xType,PriorType,
AType,QType,
FType,LType,FHGMethodType,FHGCacheType,
HType,matType,bkType,diffusionType,diffModelType,measModType,measType,
HType,vecType,matType,bkType,diffusionType,diffModelType,measModType,measType,
puType,llType,dtType,rateType,UF,JC,uNoUnitsType,
} <: AbstractODEFilterCache
# Constants
Expand Down Expand Up @@ -53,6 +53,7 @@ mutable struct EKCache{
K1::matType
G1::matType
Smat::HType
C_d::vecType
C_dxd::matType
C_dxD::matType
C_Dxd::matType
Expand Down Expand Up @@ -187,6 +188,7 @@ function OrdinaryDiffEq.alg_cache(
G = factorized_similar(FAC, D, D)
Smat = factorized_similar(FAC, d, d)

C_d = similar(Array{uElType}, d)
C_dxd = factorized_similar(FAC, d, d)
C_dxD = factorized_similar(FAC, d, D)
C_Dxd = factorized_similar(FAC, D, d)
Expand Down Expand Up @@ -227,7 +229,7 @@ function OrdinaryDiffEq.alg_cache(
typeof(R),typeof(FAC),typeof(Proj),typeof(SolProj),typeof(P),typeof(PI),typeof(E0),
uType,typeof(du),typeof(x0),typeof(prior),typeof(A),typeof(Q),
typeof(F),typeof(L),typeof(FHG_method),typeof(FHG_cache),
typeof(H),matType,
typeof(H),typeof(C_d),matType,
typeof(backward_kernel),typeof(initdiff),
typeof(diffmodel),typeof(measurement_model),typeof(measurement),typeof(pu_tmp),
uEltypeNoUnits,typeof(dt),typeof(du1),typeof(uf),typeof(jac_config),typeof(atmp),
Expand All @@ -239,7 +241,7 @@ function OrdinaryDiffEq.alg_cache(
x0, xprev, x_pred, x_filt, x_tmp, x_tmp2,
measurement, m_tmp, pu_tmp,
H, du, ddu, K, G, Smat,
C_dxd, C_dxD, C_Dxd, C_DxD, C_2DxD, C_3DxD,
C_d, C_dxd, C_dxD, C_Dxd, C_DxD, C_2DxD, C_3DxD,
backward_kernel,
initdiff, initdiff * NaN, initdiff * NaN,
err_tmp, ll, dt, du1, uf, jac_config,
Expand Down
23 changes: 19 additions & 4 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ function update!(
K2_cache::AbstractMatrix,
M_cache::AbstractMatrix,
C_dxd::AbstractMatrix,
C_d::AbstractVector,
)
z, S = measurement.μ, measurement.Σ
m_p, P_p = x_pred.μ, x_pred.Σ
Expand Down Expand Up @@ -98,7 +99,11 @@ function update!(
_matmul!(K2_cache, P_p, H')
end

rdiv!(K, length(_S) == 1 ? _S[1] : cholesky!(_S))
S_chol = length(_S) == 1 ? _S[1] : cholesky!(_S)
rdiv!(K, S_chol)

loglikelihood = zero(eltype(K))
loglikelihood = pn_logpdf!(measurement, S_chol, C_d)

# x_out.μ .= m_p .+ K * (0 .- z)
x_out.μ .= m_p .- _matmul!(x_out.μ, K, z)
Expand All @@ -111,7 +116,16 @@ function update!(

fast_X_A_Xt!(x_out.Σ, P_p, M_cache)

return x_out
return x_out, loglikelihood
end
function pn_logpdf!(measurement, S_chol, tmpmean)
μ = measurement.μ
Σ = S_chol

d = length(μ)
z = ldiv!(Σ, copy!(tmpmean, μ))

return -0.5 * μ'z - 0.5 * d * log(2π) - 0.5 * logdet(Σ)
end

# Kronecker version
Expand All @@ -124,6 +138,7 @@ function update!(
K2_cache::IsometricKroneckerProduct,
M_cache::IsometricKroneckerProduct,
C_dxd::IsometricKroneckerProduct,
C_d::AbstractVector,
) where {T}
D = length(x_out.μ) # full_state_dim
d = H.ldim # ode_dimension_dim
Expand All @@ -138,6 +153,6 @@ function update!(
_M_cache = M_cache.B
_C_dxd = C_dxd.B

update!(_x_out, _x_pred, _measurement, _H, _K1_cache, _K2_cache, _M_cache, _C_dxd)
return x_out
_, loglikelihood = update!(_x_out, _x_pred, _measurement, _H, _K1_cache, _K2_cache, _M_cache, _C_dxd, C_d)
return x_out, loglikelihood
end
4 changes: 2 additions & 2 deletions src/initialization/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ function init_condition_on!(
data::AbstractVector,
cache,
)
@unpack x_tmp, K1, C_Dxd, C_DxD, C_dxd, m_tmp = cache
@unpack x_tmp, K1, C_Dxd, C_DxD, C_dxd, m_tmp, C_d = cache

# measurement mean
_matmul!(m_tmp.μ, H, x.μ)
Expand All @@ -130,5 +130,5 @@ function init_condition_on!(
# measurement cov
fast_X_A_Xt!(m_tmp.Σ, x.Σ, H)
copy!(x_tmp, x)
update!(x, x_tmp, m_tmp, H, K1, C_Dxd, C_DxD, C_dxd)
update!(x, x_tmp, m_tmp, H, K1, C_Dxd, C_DxD, C_dxd, C_d)
end
11 changes: 6 additions & 5 deletions src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,15 @@ function OrdinaryDiffEq.perform_step!(integ, cache::EKCache, repeat_step=false)

# Compute measurement covariance only now; likelihood computation is currently broken
compute_measurement_covariance!(cache)
cache.log_likelihood = logpdf(cache.measurement, zeros(d))
integ.sol.log_likelihood += cache.log_likelihood

# Update state and save the ODE solution value
x_filt = update!(cache, x_pred)
x_filt, loglikelihood = update!(cache, x_pred)
write_into_solution!(
integ.u, x_filt.μ; cache, is_secondorder_ode=integ.f isa DynamicalODEFunction)

cache.log_likelihood = loglikelihood
integ.sol.log_likelihood += cache.log_likelihood

# Update the global diffusion MLE (if applicable)
if !isdynamic(cache.diffusionmodel)
cache.global_diffusion = estimate_global_diffusion(cache.diffusionmodel, integ)
Expand Down Expand Up @@ -163,10 +164,10 @@ compute_measurement_covariance!(cache) =

function update!(cache, prediction)
@unpack measurement, H, x_filt, K1, m_tmp, C_DxD = cache
@unpack C_dxd, C_Dxd = cache
@unpack C_dxd, C_Dxd, C_d = cache
K2 = C_Dxd

return update!(x_filt, prediction, measurement, H, K1, K2, C_DxD, C_dxd)
return update!(x_filt, prediction, measurement, H, K1, K2, C_DxD, C_dxd, C_d)
end

"""
Expand Down

0 comments on commit 0cd7654

Please sign in to comment.