Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement observation noise for the PN likelihood #299

Merged
merged 17 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ X_A_Xt(A, X) = X * A * X'
stack(x) = copy(reduce(hcat, x)')
vecvec2mat(x) = reduce(hcat, x)'

# Throw a `DimensionMismatch` unless `cov` is a square `d`×`d` matrix.
# An explicit `throw` is used instead of `@assert` because `cov` is user input and
# assertions are not guaranteed to run at every optimization level.
function _check_cov_size(cov, d)
    rows, cols = size(cov, 1), size(cov, 2)
    rows == cols == d || throw(
        DimensionMismatch("covariance has size ($rows, $cols), expected ($d, $d)"),
    )
    return nothing
end

"""
    cov2psdmatrix(cov; d)

Convert the covariance specification `cov` into a `PSDMatrix` for a `d`-dimensional
quantity.

The `PSDMatrix` is built from a square-root factor of `cov`:
- `cov::Number`: `sqrt(cov) * Eye(d)`.
- `cov::UniformScaling`: `sqrt(cov.λ) * Eye(d)`.
- `cov::Diagonal`: the element-wise square root of `cov`.
- `cov::AbstractMatrix`: the upper Cholesky factor of `cov`.
- `cov::PSDMatrix`: `cov` itself is returned unchanged.

# Throws
- `DimensionMismatch` if a matrix-valued `cov` is not of size `(d, d)`.
"""
cov2psdmatrix(cov::Number; d) = PSDMatrix(sqrt(cov) * Eye(d))
cov2psdmatrix(cov::UniformScaling; d) = PSDMatrix(sqrt(cov.λ) * Eye(d))
function cov2psdmatrix(cov::Diagonal; d)
    _check_cov_size(cov, d)
    return PSDMatrix(sqrt.(cov))
end
function cov2psdmatrix(cov::AbstractMatrix; d)
    _check_cov_size(cov, d)
    return PSDMatrix(cholesky(cov).U)
end
function cov2psdmatrix(cov::PSDMatrix; d)
    _check_cov_size(cov, d)
    return cov
end

include("fast_linalg.jl")
include("kronecker.jl")
include("covariance_structure.jl")
Expand Down
14 changes: 10 additions & 4 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,21 @@ julia> solve(prob, EK0())

# [References](@ref references)
"""
struct EK0{PT,DT,IT,RT} <: AbstractEK
    prior::PT
    diffusionmodel::DT
    smooth::Bool
    initialization::IT
    # Observation-noise specification for the PN likelihood. With `nothing` the cache
    # keeps `R = nothing`; otherwise the value is converted via `cov2psdmatrix`.
    pn_observation_noise::RT
end

# Keyword constructor.
# `pn_obervation_noise` (sic) is the misspelled name this keyword was first introduced
# under; it is still accepted so that existing callers keep working. The correctly
# spelled `pn_observation_noise` takes precedence when it is given.
function EK0(;
    order=3,
    prior=IWP(order),
    diffusionmodel=DynamicDiffusion(),
    smooth=true,
    initialization=TaylorModeInit(num_derivatives(prior)),
    pn_observation_noise=nothing,
    pn_obervation_noise=nothing,
)
    noise = isnothing(pn_observation_noise) ? pn_obervation_noise : pn_observation_noise
    return EK0(prior, diffusionmodel, smooth, initialization, noise)
end

# Strip a `Val` wrapper: an instance `Val{P}()` yields its type parameter `P`.
function _unwrap_val(::Val{P}) where {P}
    return P
end

# Anything that is not a `Val` instance is passed through untouched.
function _unwrap_val(x)
    return x
end
Expand Down Expand Up @@ -92,11 +94,12 @@ julia> solve(prob, EK1())

# [References](@ref references)
"""
# The leading type parameters are not tied to any field; the keyword constructor fills
# them from solver options (e.g. `_unwrap_val(chunk_size)`, `_unwrap_val(autodiff)`).
struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK
    prior::PT
    diffusionmodel::DT
    smooth::Bool
    initialization::IT
    # Observation-noise specification for the PN likelihood. With `nothing` the cache
    # keeps `R = nothing`; otherwise the value is converted via `cov2psdmatrix`.
    pn_observation_noise::RT
end
EK1(;
order=3,
Expand All @@ -109,7 +112,8 @@ EK1(;
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
) where {PT,DT,IT} =
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} =
EK1{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
Expand All @@ -119,11 +123,13 @@ EK1(;
PT,
DT,
IT,
RT,
}(
prior,
diffusionmodel,
smooth,
initialization,
pn_observation_noise,
)

"""
Expand Down
8 changes: 6 additions & 2 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,14 @@ function OrdinaryDiffEq.alg_cache(
copy!(x0.Σ, apply_diffusion(x0.Σ, initdiff))

# Measurement model related things
R = nothing # factorized_similar(FAC, d, d)
R = if isnothing(alg.pn_observation_noise)
nothing
else
to_factorized_matrix(FAC, cov2psdmatrix(alg.pn_observation_noise; d))
end
H = factorized_similar(FAC, d, D)
v = similar(Array{uElType}, d)
S = PSDMatrix(factorized_zeros(FAC, D, d))
S = factorized_zeros(FAC, d, d)
measurement = Gaussian(v, S)

# Caches
Expand Down
15 changes: 7 additions & 8 deletions src/diffusions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function estimate_global_diffusion(::FixedDiffusion, integ)

v, S = measurement.μ, measurement.Σ
e = m_tmp.μ
_S = _matmul!(Smat, S.R', S.R)
_S = S
e .= v
diffusion_t = if _S isa IsometricKroneckerProduct
@assert length(_S.B) == 1
Expand Down Expand Up @@ -124,8 +124,7 @@ function estimate_global_diffusion(::FixedMVDiffusion, integ)

v, S = measurement.μ, measurement.Σ
# S_11 = diag(S)[1]
c1 = view(S.R, :, 1)
S_11 = dot(c1, c1)
S_11 = S[1, 1]

Σ_ii = v .^ 2 ./ S_11
Σ = Diagonal(Σ_ii)
Expand Down Expand Up @@ -159,11 +158,11 @@ For more background information
* [bosch20capos](@cite) Bosch et al, "Calibrated Adaptive Probabilistic ODE Solvers", AISTATS (2021)
"""
function local_scalar_diffusion(cache)
@unpack d, R, H, Qh, measurement, m_tmp, Smat = cache
@unpack d, R, H, Qh, measurement, m_tmp, Smat, C_Dxd = cache
z = measurement.μ
e, HQH = m_tmp.μ, m_tmp.Σ
fast_X_A_Xt!(HQH, Qh, H)
HQHmat = _matmul!(Smat, HQH.R', HQH.R)
_matmul!(C_Dxd, Qh.R, H')
HQHmat = _matmul!(Smat, C_Dxd', C_Dxd)
e .= z
σ² = if HQHmat isa IsometricKroneckerProduct
@assert length(HQHmat.B) == 1
Expand Down Expand Up @@ -195,9 +194,9 @@ function local_diagonal_diffusion(cache)
@unpack d, q, H, Qh, measurement, m_tmp, tmp = cache
@unpack local_diffusion = cache
z = measurement.μ
HQH = fast_X_A_Xt!(m_tmp.Σ, Qh, H)
HQHR = _matmul!(cache.C_Dxd, Qh.R, H')
# Q0_11 = diag(HQH)[1]
c1 = view(HQH.R, :, 1)
c1 = view(HQHR, :, 1)
nathanaelbosch marked this conversation as resolved.
Show resolved Hide resolved
Q0_11 = dot(c1, c1)

Σ_ii = @. m_tmp.μ = z^2 / Q0_11
Expand Down
18 changes: 13 additions & 5 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,16 @@ function update!(
fast_X_A_Xt!(x_out.Σ, P_p, M_cache)

if !isnothing(R)
x_out.Σ.R .= triangularize!([x_out.Σ.R; R.R * K']; cachemat=M_cache)
# M = Matrix(x_out.Σ) + K * Matrix(R) * K'
_matmul!(M_cache, x_out.Σ.R', x_out.Σ.R)
_matmul!(K1_cache, K, R.R')
_matmul!(M_cache, K1_cache, K1_cache', 1, 1)
chol = cholesky!(Symmetric(M_cache), check=false)
if issuccess(chol)
copy!(x_out.Σ.R, chol.U)
else
x_out.Σ.R .= triangularize!([x_out.Σ.R; R.R * K']; cachemat=M_cache)
nathanaelbosch marked this conversation as resolved.
Show resolved Hide resolved
end
end

return x_out, loglikelihood
Expand All @@ -141,7 +150,7 @@ end
function update!(
x_out::SRGaussian{T,<:IsometricKroneckerProduct},
x_pred::SRGaussian{T,<:IsometricKroneckerProduct},
measurement::SRGaussian{T,<:IsometricKroneckerProduct},
measurement::Gaussian{<:AbstractVector,<:IsometricKroneckerProduct},
H::IsometricKroneckerProduct,
K1_cache::IsometricKroneckerProduct,
K2_cache::IsometricKroneckerProduct,
Expand All @@ -155,8 +164,7 @@ function update!(
Q = D ÷ d # n_derivatives_dim
_x_out = Gaussian(reshape_no_alloc(x_out.μ, Q, d), PSDMatrix(x_out.Σ.R.B))
_x_pred = Gaussian(reshape_no_alloc(x_pred.μ, Q, d), PSDMatrix(x_pred.Σ.R.B))
_measurement = Gaussian(
reshape_no_alloc(measurement.μ, 1, d), PSDMatrix(measurement.Σ.R.B))
_measurement = Gaussian(reshape_no_alloc(measurement.μ, 1, d), measurement.Σ.B)
_H = H.B
_K1_cache = K1_cache.B
_K2_cache = K2_cache.B
Expand All @@ -180,7 +188,7 @@ function update!(
end

# Short-hand with cache
function update!(x_out, x, measurement, H; R=nothing, cache)
function update!(x_out, x, measurement, H; cache, R=nothing)
@unpack K1, m_tmp, C_DxD, C_dxd, C_Dxd, C_d = cache
K2 = C_Dxd
return update!(x_out, x, measurement, H, K1, K2, C_DxD, C_dxd, C_d; R)
Expand Down
3 changes: 2 additions & 1 deletion src/initialization/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ function init_condition_on!(
m_tmp.μ .-= data

# measurement cov
fast_X_A_Xt!(m_tmp.Σ, x.Σ, H)
_matmul!(C_Dxd, x.Σ.R, H')
_matmul!(m_tmp.Σ, C_Dxd', C_Dxd)
copy!(x_tmp, x)
update!(x, x_tmp, m_tmp, H, K1, C_Dxd, C_DxD, C_dxd, C_d)
end
12 changes: 8 additions & 4 deletions src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ function OrdinaryDiffEq.perform_step!(integ, cache::EKCache, repeat_step=false)
compute_measurement_covariance!(cache)

# Update state and save the ODE solution value
x_filt, loglikelihood = update!(x_filt, x_pred, cache.measurement, cache.H; cache)
x_filt, loglikelihood = update!(
x_filt, x_pred, cache.measurement, cache.H; cache, R=cache.R)
write_into_solution!(
integ.u, x_filt.μ; cache, is_secondorder_ode=integ.f isa DynamicalODEFunction)

Expand Down Expand Up @@ -159,8 +160,11 @@ function evaluate_ode!(integ, x_pred, t)
end

"""
    compute_measurement_covariance!(cache)

Compute the measurement covariance in-place and store it in `cache.measurement.Σ`.

With `Rp = cache.x_pred.Σ.R` and `H = cache.H`, the product `Rp * H'` is written into
`cache.C_Dxd`, and `cache.C_Dxd' * cache.C_Dxd` is written into `cache.measurement.Σ`.
If `cache.R` is not `nothing`, `cache.R.R' * cache.R.R` is computed in the buffer
`cache.C_dxd` and added on top.

NOTE(review): assumes `_matmul!(C, A, B)` writes `A * B` into `C` and returns `C` —
confirm against its definition.
"""
function compute_measurement_covariance!(cache)
    # The product stays in `cache.C_Dxd` after this call; `estimate_errors!` reads
    # that buffer (presumably relying on this value — verify).
    _matmul!(cache.C_Dxd, cache.x_pred.Σ.R, cache.H')
    _matmul!(cache.measurement.Σ, cache.C_Dxd', cache.C_Dxd)
    if !isnothing(cache.R)
        cache.measurement.Σ .+= _matmul!(cache.C_dxd, cache.R.R', cache.R.R)
    end
end

"""
Expand Down Expand Up @@ -216,7 +220,7 @@ To save allocations, the function modifies the given `cache` and writes into
function estimate_errors!(cache::AbstractODEFilterCache)
@unpack local_diffusion, Qh, H, d = cache

R = cache.measurement.Σ.R
R = cache.C_Dxd

if local_diffusion isa Diagonal
_QR = cache.C_DxD .= Qh.R .* sqrt.(local_diffusion.diag)'
Expand Down
Loading