Skip to content

Commit

Permalink
Make the data likelihoods better
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 17, 2024
1 parent 786b1ed commit 29be0c0
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 30 deletions.
20 changes: 14 additions & 6 deletions src/blockdiagonals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,9 @@ end
for _mul! in (:mul!, :_matmul!)
@eval $_mul!(C::BlockDiag, A::BlockDiag, B::Diagonal) = begin
local i = 1
map(zip(blocks(C), blocks(A))) do (Ci, Ai)
@assert nblocks(C) == nblocks(A)
for j in eachindex(blocks(C))
Ci, Ai = blocks(C)[j], blocks(A)[j]
d = size(Ai, 2)
$_mul!(Ci, Ai, Diagonal(view(B.diag, i:(i+d-1))))
i += d
Expand All @@ -263,7 +265,9 @@ for _mul! in (:mul!, :_matmul!)
end
@eval $_mul!(C::BlockDiag, A::Diagonal, B::BlockDiag) = begin
local i = 1
map(zip(blocks(C), blocks(B))) do (Ci, Bi)
@assert nblocks(C) == nblocks(B)
for j in eachindex(blocks(C))
Ci, Bi = blocks(C)[j], blocks(B)[j]
d = size(Bi, 1)
$_mul!(Ci, Diagonal(view(A.diag, i:(i+d-1))), Bi)
i += d
Expand All @@ -272,18 +276,22 @@ for _mul! in (:mul!, :_matmul!)
end
@eval $_mul!(C::BlockDiag, A::BlockDiag, B::Diagonal, alpha::Number, beta::Number) = begin
local i = 1
map(zip(blocks(C), blocks(A))) do (Ci, Ai)
@assert nblocks(C) == nblocks(A)
for j in eachindex(blocks(C))
Ci, Ai = blocks(C)[j], blocks(A)[j]
d = size(Ai, 2)
$_mul!(Ci, Ai, Diagonal(view(B.diag, i:(i+d-1))), alpha, beta)
i += d
end
return C
end
@eval $_mul!(C::BlockDiag, A::Diagonal, B::BlockDiag, alpha::Number, beta::Number) = begin
local i = 1
map(zip(blocks(C), blocks(B))) do (Ci, Bi)
i = 1
@assert nblocks(C) == nblocks(B)
for j in eachindex(blocks(C))
Ci, Bi = blocks(C)[j], blocks(B)[j]
d = size(Bi, 1)
$_mul!(Ci, Diagonal(view(A.diag, i:(i+d-1))), Bi, alpha, beta)
@inbounds $_mul!(Ci, Diagonal(view(A.diag, i:(i+d-1))), Bi, alpha, beta)
i += d
end
return C
Expand Down
40 changes: 33 additions & 7 deletions src/callbacks/dataupdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ function DataUpdateCallback(
val = values[idx]

o = length(val)
d = integ.cache.d

@unpack x, E0, m_tmp, G1 = integ.cache
M = observation_matrix
Expand All @@ -67,12 +68,16 @@ function DataUpdateCallback(

obs = Gaussian(obs_mean, obs_cov)

@unpack x_tmp, K1, C_DxD, C_dxd, C_Dxd, C_d = integ.cache
K1 = K1 * M'
C_dxd = M * C_dxd * M'
C_Dxd = C_Dxd * M'
C_d = M * C_d
_x = copy!(x_tmp, x)
_cache = if o != d
if !(integ.alg isa EK1)
error("Partial observations only work with the EK1 right now")
end
make_obssized_cache(integ.cache; o)
else
integ.cache
end
@unpack K1, C_DxD, C_dxd, C_Dxd, C_d = _cache
_x = copy!(integ.cache.x_tmp, x)
_, ll = update!(x, _x, obs, H, K1, C_Dxd, C_DxD, C_dxd, C_d; R=R)

if !isnothing(loglikelihood)
Expand All @@ -91,4 +96,25 @@ make_obscov_sqrt(
) =
IsometricKroneckerProduct(PR.ldim, make_obscov_sqrt(PR.B, H.B, RR.B))
make_obscov_sqrt(PR::BlockDiag, H::BlockDiag, RR::BlockDiag) =
BlockDiag([make_obscov_sqrt(blocks(PR)[i], blocks(H)[i], blocks(RR)[i]) for i in eachindex(blocks(PR))])
BlockDiag([
make_obscov_sqrt(blocks(PR)[i], blocks(H)[i], blocks(RR)[i]) for
i in eachindex(blocks(PR))
])

function make_obssized_cache(cache; o)
@unpack K1, C_DxD, C_dxd, C_Dxd, C_d, m_tmp, x_tmp = cache
return make_obssized_cache(K1, C_DxD, C_dxd, C_Dxd, C_d, m_tmp, x_tmp; o)
end
function make_obssized_cache(
K1::M, C_DxD::M, C_dxd::M, C_Dxd::M, C_d::V, m_tmp::G, x_tmp; o,
) where {M<:Matrix,V<:Vector,G<:Gaussian}
return (
K1=view(K1, :, 1:o),
C_dxd=view(C_dxd, 1:o, 1:o),
C_Dxd=view(C_Dxd, :, 1:o),
C_d=view(C_d, 1:o),
C_DxD=C_DxD,
m_tmp=Gaussian(view(m_tmp.μ, 1:o), view(m_tmp.Σ, 1:o, 1:o)),
x_tmp=x_tmp,
)
end
2 changes: 1 addition & 1 deletion src/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function check_densesmooth(integ)
error("To use `dense=true` you need to set `smooth=true`!")
end
if !integ.opts.save_everystep && integ.alg.smooth
error("If you do not save all values, you do not need to smooth!")
error("If you set `save_everystep=false` also set `smooth=false` in the alg!")
end
end
function check_saveiter(integ)
Expand Down
19 changes: 8 additions & 11 deletions src/data_likelihoods/fenrir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,14 @@ function fit_pnsolution_to_data!(
LL = zero(eltype(sol.prob.p))

o = length(data.u[1])
@unpack x_tmp, C_dxd, C_d, K1, C_Dxd, C_DxD, m_tmp = cache
_cache = (
x_tmp=x_tmp,
C_DxD=C_DxD,
C_Dxd=C_Dxd * proj',
C_dxd=proj * C_dxd * proj',
C_d=proj * C_d,
K1=K1 * proj',
K2=C_Dxd * proj',
m_tmp=proj * m_tmp,
)
d = cache.d
@unpack x_tmp, m_tmp = cache
_cache = if o != d
make_obssized_cache(cache; o)
else
cache
end
@unpack K1, C_DxD, C_dxd, C_Dxd, C_d = _cache

x_posterior = copy(sol.x_filt) # the object to be filled
state2data_projmat = proj * cache.SolProj
Expand Down
3 changes: 2 additions & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ function calc_H!(H, integ, cache)
elseif integ.alg isa DiagonalEK1
calc_H_EK0!(H, integ, cache)
OrdinaryDiffEq.calc_J!(ddu, integ, cache, true)
_matmul!(H, Diagonal(ddu), cache.SolProj, -1.0, 1.0)
ddu_diag = Diagonal(ddu)
_matmul!(H, ddu_diag, cache.SolProj, -1.0, 1.0)
else
error("Unknown algorithm")
end
Expand Down
8 changes: 4 additions & 4 deletions src/priors/iwp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ end
function initialize_transition_matrices(FAC::IsometricKroneckerCovariance, p::IWP, dt)
A, Q = preconditioned_discretize(p)
P, PI = initialize_preconditioner(FAC, p, dt)
Ah = PI * A * P
Qh = PSDMatrix(Q.R * PI)
Ah = copy(A)
Qh = copy(Q)
return A, Q, Ah, Qh, P, PI
end
function initialize_transition_matrices(FAC::DenseCovariance, p::IWP, dt)
Expand All @@ -174,8 +174,8 @@ function initialize_transition_matrices(FAC::BlockDiagonalCovariance, p::IWP, dt
A = to_factorized_matrix(FAC, A)
Q = to_factorized_matrix(FAC, Q)
P, PI = initialize_preconditioner(FAC, p, dt)
Ah = PI * A * P
Qh = PSDMatrix(Q.R * PI)
Ah = copy(A)
Qh = copy(Q)
return A, Q, Ah, Qh, P, PI
end

Expand Down

0 comments on commit 29be0c0

Please sign in to comment.