Implement observation noise for the PN likelihood (#299)
* Implement observation noise for the PN likelihood

* Fix some issues that this created with the EK0

* Add check to make sure that calibration is off when R>0

* Use cov2psdmatrix also in the data likelihoods

* Slightly more verbose error if R>0 and calibrate=true

* More compact code in caches.jl

* Add a warning whenever we use `triangularize!`

just to be more aware of what actually happens

* Make multivariate diffusions illegal with the EK1

* Revisit some diffusion-related stuff in diffusions.jl

* Fix a bug with classicsolverinit

* Make update compatible with PSDMatrix-valued marginal obs covs again

* Properly fix the RKinit bug

* JuliaFormatter.jl

* Fix more tests

* JuliaFormatter.jl _again_

* Algorithm check was bad; this fixes it

* JuliaFormatter.jl
nathanaelbosch authored Feb 9, 2024
1 parent 70aee09 commit c430bb8
Showing 10 changed files with 135 additions and 91 deletions.
8 changes: 8 additions & 0 deletions src/ProbNumDiffEq.jl
@@ -43,6 +43,14 @@ X_A_Xt(A, X) = X * A * X'
stack(x) = copy(reduce(hcat, x)')
vecvec2mat(x) = reduce(hcat, x)'

cov2psdmatrix(cov::Number; d) = PSDMatrix(sqrt(cov) * Eye(d))
cov2psdmatrix(cov::UniformScaling; d) = PSDMatrix(sqrt(cov.λ) * Eye(d))
cov2psdmatrix(cov::Diagonal; d) =
(@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(sqrt.(cov)))
cov2psdmatrix(cov::AbstractMatrix; d) =
(@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(Matrix(cholesky(cov).U)))
cov2psdmatrix(cov::PSDMatrix; d) = (@assert size(cov, 1) == size(cov, 2) == d; cov)
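The new helper funnels every accepted observation-noise format into a single `PSDMatrix` representation. A rough usage sketch (these are internal, unexported helpers; the `PSDMatrix` convention here is cov == RᵀR):

```julia
using ProbNumDiffEq, LinearAlgebra

# Each call returns a d×d PSDMatrix holding the square-root factor R with cov == R' * R:
ProbNumDiffEq.cov2psdmatrix(0.1; d=2)                    # scalar variance
ProbNumDiffEq.cov2psdmatrix(0.1I; d=2)                   # UniformScaling
ProbNumDiffEq.cov2psdmatrix(Diagonal([0.1, 0.2]); d=2)   # per-dimension variances
ProbNumDiffEq.cov2psdmatrix([0.2 0.1; 0.1 0.2]; d=2)     # dense SPD matrix, via its Cholesky factor
```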

include("fast_linalg.jl")
include("kronecker.jl")
include("covariance_structure.jl")
103 changes: 67 additions & 36 deletions src/algorithms.jl
@@ -3,6 +3,26 @@
########################################################################################
abstract type AbstractEK <: OrdinaryDiffEq.OrdinaryDiffEqAdaptiveAlgorithm end

function ekargcheck(alg; diffusionmodel, pn_observation_noise, kwargs...)
if (isstatic(diffusionmodel) && diffusionmodel.calibrate) &&
(!isnothing(pn_observation_noise) && !iszero(pn_observation_noise))
throw(
ArgumentError(
"Automatic calibration of global diffusion models is not possible when using observation noise. If you want to calibrate a global diffusion parameter, do so setting `calibrate=false` and optimizing `sol.pnstats.log_likelihood` manually.",
),
)
end
if (
(diffusionmodel isa FixedMVDiffusion && diffusionmodel.calibrate) ||
diffusionmodel isa DynamicMVDiffusion) && alg == EK1
throw(
ArgumentError(
"The `EK1` algorithm does not support automatic calibration of multivariate diffusion models. Either use the `EK0` instead, or use a scalar diffusion model, or set `calibrate=false` and calibrate manually by optimizing `sol.pnstats.log_likelihood`.",
),
)
end
end
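To make the constraints concrete, here is a hedged illustration of what `ekargcheck` rejects and accepts (assuming `FixedDiffusion()` defaults to `calibrate=true`; error messages paraphrased):

```julia
using ProbNumDiffEq

# Rejected: calibrated global diffusion combined with observation noise.
EK1(pn_observation_noise=0.1, diffusionmodel=FixedDiffusion())
# -> ArgumentError: set calibrate=false and optimize sol.pnstats.log_likelihood instead

# Rejected: multivariate diffusion with the EK1.
EK1(diffusionmodel=DynamicMVDiffusion())
# -> ArgumentError: use the EK0, a scalar diffusion model, or calibrate=false

# Accepted: observation noise together with an uncalibrated global diffusion.
EK1(pn_observation_noise=0.1, diffusionmodel=FixedDiffusion(calibrate=false))
```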

"""
EK0(; order=3,
smooth=true,
@@ -38,19 +58,24 @@ julia> solve(prob, EK0())
# [References](@ref references)
"""
struct EK0{PT,DT,IT} <: AbstractEK
struct EK0{PT,DT,IT,RT} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
EK0(; order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} = begin
ekargcheck(EK0; diffusionmodel, pn_observation_noise)
new{PT,DT,IT,RT}(
prior, diffusionmodel, smooth, initialization, pn_observation_noise)
end
end
EK0(;
order=3,
prior=IWP(order),
diffusionmodel=DynamicDiffusion(),
smooth=true,
initialization=TaylorModeInit(num_derivatives(prior)),
) = EK0(prior, diffusionmodel, smooth, initialization)

_unwrap_val(::Val{B}) where {B} = B
_unwrap_val(B) = B
@@ -92,39 +117,45 @@ julia> solve(prob, EK1())
# [References](@ref references)
"""
struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT} <: AbstractEK
struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
EK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} = begin
ekargcheck(EK1; diffusionmodel, pn_observation_noise)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
PT,
DT,
IT,
RT,
}(
prior,
diffusionmodel,
smooth,
initialization,
pn_observation_noise,
)
end
end
EK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
) where {PT,DT,IT} =
EK1{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
PT,
DT,
IT,
}(
prior,
diffusionmodel,
smooth,
initialization,
)
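Putting the new keyword together with the calibration constraint, a minimal end-to-end sketch (problem, noise level, and diffusion choice are purely illustrative; `ODEProblem`/`solve` are assumed to be re-exported as in the package docs):

```julia
using ProbNumDiffEq

prob = ODEProblem((du, u, p, t) -> (du .= -u), [1.0], (0.0, 1.0))

# With pn_observation_noise > 0, automatic calibration must be off; the diffusion
# can instead be tuned by maximizing sol.pnstats.log_likelihood.
alg = EK1(pn_observation_noise=1e-2, diffusionmodel=FixedDiffusion(calibrate=false))
sol = solve(prob, alg)
sol.pnstats.log_likelihood
```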

"""
ExpEK(; L, order=3, kwargs...)
6 changes: 4 additions & 2 deletions src/caches.jl
@@ -168,10 +168,12 @@ function OrdinaryDiffEq.alg_cache(
copy!(x0.Σ, apply_diffusion(x0.Σ, initdiff))

# Measurement model related things
R = nothing # factorized_similar(FAC, d, d)
R =
isnothing(alg.pn_observation_noise) ? nothing :
to_factorized_matrix(FAC, cov2psdmatrix(alg.pn_observation_noise; d))
H = factorized_similar(FAC, d, D)
v = similar(Array{uElType}, d)
S = PSDMatrix(factorized_zeros(FAC, D, d))
S = factorized_zeros(FAC, d, d)
measurement = Gaussian(v, S)

# Caches
10 changes: 1 addition & 9 deletions src/callbacks/dataupdate.jl
@@ -64,15 +64,7 @@ function DataUpdateCallback(
obs_mean = _matmul!(view(m_tmp.μ, 1:o), H, x.μ)
obs_mean .-= val

R = if observation_noise_cov isa PSDMatrix
observation_noise_cov
elseif observation_noise_cov isa Number
PSDMatrix(sqrt(observation_noise_cov) * Eye(o))
elseif observation_noise_cov isa UniformScaling
PSDMatrix(sqrt(observation_noise_cov.λ) * Eye(o))
else
PSDMatrix(cholesky(observation_noise_cov).U)
end
R = cov2psdmatrix(observation_noise_cov; d=o)

# _A = x.Σ.R * H'
# obs_cov = _A'_A + R
20 changes: 6 additions & 14 deletions src/data_likelihoods/fenrir.jl
@@ -56,15 +56,7 @@ function fenrir_data_loglik(

# Fit the ODE solution / PN posterior to the provided data; this is the actual Fenrir
o = length(data.u[1])
R = if observation_noise_cov isa PSDMatrix
observation_noise_cov
elseif observation_noise_cov isa Number
PSDMatrix(sqrt(observation_noise_cov) * Eye(o))
elseif observation_noise_cov isa UniformScaling
PSDMatrix(sqrt(observation_noise_cov.λ) * Eye(o))
else
PSDMatrix(cholesky(observation_noise_cov).U)
end
R = cov2psdmatrix(observation_noise_cov; d=o)
LL, _, _ = fit_pnsolution_to_data!(sol, R, data; proj=observation_matrix)

return LL
@@ -91,7 +83,7 @@ function fit_pnsolution_to_data!(
C_d=view(C_d, 1:o),
K1=view(K1, :, 1:o),
K2=view(C_Dxd, :, 1:o),
m_tmp=Gaussian(view(m_tmp.μ, 1:o), PSDMatrix(view(m_tmp.Σ.R, :, 1:o))),
m_tmp=Gaussian(view(m_tmp.μ, 1:o), view(m_tmp.Σ, 1:o, 1:o)),
)

x_posterior = copy(sol.x_filt) # the object to be filled
@@ -144,10 +136,10 @@ function measure_and_update!(x, u, H, R::PSDMatrix, cache)
z, S = cache.m_tmp
_matmul!(z, H, x.μ)
z .-= u
fast_X_A_Xt!(S, x.Σ, H)
# _S = PSDMatrix(S.R'S.R + R.R'R.R)
_S = PSDMatrix(triangularize!([S.R; R.R], cachemat=cache.C_DxD))
msmnt = Gaussian(z, _S)
_matmul!(cache.C_Dxd, x.Σ.R, H')
_matmul!(S, cache.C_Dxd', cache.C_Dxd)
S .+= _matmul!(cache.C_dxd, R.R', R.R)
msmnt = Gaussian(z, S)

return update!(x, copy!(cache.x_tmp, x), msmnt, H; R=R, cache)
end
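For reference, a self-contained dense-algebra sketch of the quantity the cache-based code above assembles (toy sizes and names; the actual implementation keeps everything in square-root form):

```julia
using LinearAlgebra

d, D = 2, 6
Rx = triu(rand(D, D)); Σ = Rx' * Rx      # state covariance from its square-root factor
H  = rand(d, D)                          # projection / observation matrix
Rr = triu(rand(d, d)); R = Rr' * Rr      # observation-noise covariance
S  = H * Σ * H' + R                      # marginal measurement covariance, cf. S .+= R.R' * R.R
```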
39 changes: 19 additions & 20 deletions src/diffusions.jl
@@ -69,13 +69,12 @@ function estimate_global_diffusion(::FixedDiffusion, integ)

v, S = measurement.μ, measurement.Σ
e = m_tmp.μ
_S = _matmul!(Smat, S.R', S.R)
e .= v
diffusion_t = if _S isa IsometricKroneckerProduct
@assert length(_S.B) == 1
dot(v, e) / d / _S.B[1]
diffusion_t = if S isa IsometricKroneckerProduct
@assert length(S.B) == 1
dot(v, e) / d / S.B[1]
else
S_chol = cholesky!(_S)
S_chol = cholesky!(S)
ldiv!(S_chol, e)
dot(v, e) / d
end
@@ -123,13 +122,12 @@ function estimate_global_diffusion(::FixedMVDiffusion, integ)
@unpack d, q, measurement, local_diffusion = integ.cache

v, S = measurement.μ, measurement.Σ
# S_11 = diag(S)[1]
c1 = view(S.R, :, 1)
S_11 = dot(c1, c1)
# @assert diag(S) |> unique |> length == 1
S_11 = S[1, 1]

Σ_ii = v .^ 2 ./ S_11
Σ = Diagonal(Σ_ii)
Σ_out = kron(Σ, I(q + 1))
Σ_out = kron(Σ, I(q + 1)) # -> Different for each dimension; same for each derivative

if integ.success_iter == 0
# @assert length(diffusions) == 0
@@ -159,17 +157,17 @@ For more background information
* [bosch20capos](@cite) Bosch et al, "Calibrated Adaptive Probabilistic ODE Solvers", AISTATS (2021)
"""
function local_scalar_diffusion(cache)
@unpack d, R, H, Qh, measurement, m_tmp, Smat = cache
@unpack d, R, H, Qh, measurement, m_tmp, Smat, C_Dxd = cache
z = measurement.μ
e, HQH = m_tmp.μ, m_tmp.Σ
fast_X_A_Xt!(HQH, Qh, H)
HQHmat = _matmul!(Smat, HQH.R', HQH.R)
_matmul!(C_Dxd, Qh.R, H')
_matmul!(HQH, C_Dxd', C_Dxd)
e .= z
σ² = if HQHmat isa IsometricKroneckerProduct
@assert length(HQHmat.B) == 1
dot(z, e) / d / HQHmat.B[1]
σ² = if HQH isa IsometricKroneckerProduct
@assert length(HQH.B) == 1
dot(z, e) / d / HQH.B[1]
else
C = cholesky!(HQHmat)
C = cholesky!(HQH)
ldiv!(C, e)
dot(z, e) / d
end
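The value computed here is the quasi-maximum-likelihood local diffusion estimate σ̂² = zᵀ (H Q Hᵀ)⁻¹ z / d; a dense sketch of the same formula (illustrative names and sizes):

```julia
using LinearAlgebra

d, D = 2, 6
z   = rand(d)                            # measurement residual
Qr  = triu(rand(D, D)); Q = Qr' * Qr     # process-noise covariance Qh
H   = rand(d, D)
HQH = H * Q * H'
σ²  = dot(z, HQH \ z) / d                # quasi-MLE of the local diffusion
```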
@@ -195,16 +193,17 @@ function local_diagonal_diffusion(cache)
@unpack d, q, H, Qh, measurement, m_tmp, tmp = cache
@unpack local_diffusion = cache
z = measurement.μ
HQH = fast_X_A_Xt!(m_tmp.Σ, Qh, H)
# Q0_11 = diag(HQH)[1]
c1 = view(HQH.R, :, 1)
# HQH = H * unfactorize(Qh) * H'
# @assert HQH |> diag |> unique |> length == 1
# c1 = view(_matmul!(cache.C_Dxd, Qh.R, H'), :, 1)
c1 = mul!(view(cache.C_Dxd, :, 1:1), Qh.R, view(H, 1:1, :)')
Q0_11 = dot(c1, c1)

Σ_ii = @. m_tmp.μ = z^2 / Q0_11
# Σ_ii .= max.(Σ_ii, eps(eltype(Σ_ii)))
Σ = Diagonal(Σ_ii)

# local_diffusion = kron(Σ, I(q+1))
# -> Different for each dimension; same for each derivative
for i in 1:d
for j in (i-1)*(q+1)+1:i*(q+1)
local_diffusion[j, j] = Σ[i, i]
22 changes: 18 additions & 4 deletions src/filtering/update.jl
@@ -122,7 +122,16 @@ function update!(
fast_X_A_Xt!(x_out.Σ, P_p, M_cache)

if !isnothing(R)
x_out.Σ.R .= triangularize!([x_out.Σ.R; R.R * K']; cachemat=M_cache)
# M = Matrix(x_out.Σ) + K * Matrix(R) * K'
_matmul!(M_cache, x_out.Σ.R', x_out.Σ.R)
_matmul!(K1_cache, K, R.R')
_matmul!(M_cache, K1_cache, K1_cache', 1, 1)
chol = cholesky!(Symmetric(M_cache), check=false)
if issuccess(chol)
copy!(x_out.Σ.R, chol.U)
else
x_out.Σ.R .= triangularize!([x_out.Σ.R; K1_cache']; cachemat=M_cache)
end
end

return x_out, loglikelihood
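A dense sketch of the covariance update this branch implements, assuming (as in the standard Joseph-form update) that `M_cache` holds `I - K*H` at this point; sizes and names are illustrative:

```julia
using LinearAlgebra

d, D = 2, 6
Pr = triu(rand(D, D)); P = Pr' * Pr      # predicted covariance
H  = rand(d, D)
Rr = triu(rand(d, d)); R = Rr' * Rr      # observation-noise covariance
S  = H * P * H' + R
K  = (P * H') / S                        # Kalman gain
Σ_out = (I - K * H) * P * (I - K * H)' + K * R * K'   # update with observation noise
```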
@@ -141,7 +150,10 @@ end
function update!(
x_out::SRGaussian{T,<:IsometricKroneckerProduct},
x_pred::SRGaussian{T,<:IsometricKroneckerProduct},
measurement::SRGaussian{T,<:IsometricKroneckerProduct},
measurement::Gaussian{
<:AbstractVector,
<:Union{<:PSDMatrix{T,<:IsometricKroneckerProduct},<:IsometricKroneckerProduct},
},
H::IsometricKroneckerProduct,
K1_cache::IsometricKroneckerProduct,
K2_cache::IsometricKroneckerProduct,
@@ -156,7 +168,9 @@ function update!(
_x_out = Gaussian(reshape_no_alloc(x_out.μ, Q, d), PSDMatrix(x_out.Σ.R.B))
_x_pred = Gaussian(reshape_no_alloc(x_pred.μ, Q, d), PSDMatrix(x_pred.Σ.R.B))
_measurement = Gaussian(
reshape_no_alloc(measurement.μ, 1, d), PSDMatrix(measurement.Σ.R.B))
reshape_no_alloc(measurement.μ, 1, d),
measurement.Σ isa PSDMatrix ? PSDMatrix(measurement.Σ.R.B) : measurement.Σ.B,
)
_H = H.B
_K1_cache = K1_cache.B
_K2_cache = K2_cache.B
@@ -180,7 +194,7 @@ end
end

# Short-hand with cache
function update!(x_out, x, measurement, H; R=nothing, cache)
function update!(x_out, x, measurement, H; cache, R=nothing)
@unpack K1, m_tmp, C_DxD, C_dxd, C_Dxd, C_d = cache
K2 = C_Dxd
return update!(x_out, x, measurement, H, K1, K2, C_DxD, C_dxd, C_d; R)
3 changes: 2 additions & 1 deletion src/initialization/classicsolverinit.jl
@@ -119,7 +119,8 @@ function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt)

H = cache.E0 * PI
measurement.μ .= H * x_pred.μ .- u
fast_X_A_Xt!(measurement.Σ, x_pred.Σ, H)
_matmul!(C_Dxd, x_pred.Σ.R, H')
_matmul!(measurement.Σ, C_Dxd', C_Dxd)

update!(x_filt, x_pred, measurement, H, K1, C_Dxd, C_DxD, C_dxd, C_d)
push!(filts, copy(x_filt))
3 changes: 2 additions & 1 deletion src/initialization/common.jl
@@ -128,7 +128,8 @@ function init_condition_on!(
m_tmp.μ .-= data

# measurement cov
fast_X_A_Xt!(m_tmp.Σ, x.Σ, H)
_matmul!(C_Dxd, x.Σ.R, H')
_matmul!(m_tmp.Σ, C_Dxd', C_Dxd)
copy!(x_tmp, x)
update!(x, x_tmp, m_tmp, H, K1, C_Dxd, C_DxD, C_dxd, C_d)
end