From 82b2a8f1d2fec853952cf088fe3b56e7e4c69938 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Sat, 28 Oct 2023 16:12:40 +0200 Subject: [PATCH] Shorten some code now that CovarianceStructure holds elType, d, q --- src/caches.jl | 2 +- src/preconditioning.jl | 22 ++++++---------------- src/priors/common.jl | 7 ++++--- src/projection.jl | 22 ++++++---------------- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/src/caches.jl b/src/caches.jl index a7fddfdb7..7c84b79ca 100644 --- a/src/caches.jl +++ b/src/caches.jl @@ -106,7 +106,7 @@ function OrdinaryDiffEq.alg_cache( matType = typeof(factorized_similar(FAC, d, d)) # Projections - Proj = projection(FAC, d, q, uElType) + Proj = projection(FAC) E0, E1, E2 = Proj(0), Proj(1), Proj(2) @assert f isa SciMLBase.AbstractODEFunction SolProj = if is_secondorder_ode diff --git a/src/preconditioning.jl b/src/preconditioning.jl index ec4b0a032..5c334f8e9 100644 --- a/src/preconditioning.jl +++ b/src/preconditioning.jl @@ -1,21 +1,11 @@ -function init_preconditioner( - FAC::IsometricKroneckerCovariance, - d, - q, - ::Type{elType}=typeof(1.0), -) where {elType} - P = IsometricKroneckerProduct(d, Diagonal(ones(elType, q + 1))) - PI = IsometricKroneckerProduct(d, Diagonal(ones(elType, q + 1))) +function init_preconditioner(C::IsometricKroneckerCovariance{elType}) where {elType} + P = IsometricKroneckerProduct(C.d, Diagonal(ones(elType, C.q + 1))) + PI = IsometricKroneckerProduct(C.d, Diagonal(ones(elType, C.q + 1))) return P, PI end -function init_preconditioner( - FAC::DenseCovariance, - d, - q, - ::Type{elType}=typeof(1.0), -) where {elType} - P = kron(I(d), Diagonal(ones(elType, q + 1))) - PI = kron(I(d), Diagonal(ones(elType, q + 1))) +function init_preconditioner(C::DenseCovariance{elType}) where {elType} + P = kron(I(C.d), Diagonal(ones(elType, C.q + 1))) + PI = kron(I(C.d), Diagonal(ones(elType, C.q + 1))) return P, PI end diff --git a/src/priors/common.jl b/src/priors/common.jl index 1c8513a40..a4cad3f70 100644 --- a/src/priors/common.jl +++ b/src/priors/common.jl @@ -1,12 +1,13 @@ abstract type AbstractODEFilterPrior{elType} end function initialize_preconditioner( - FAC::CovarianceStructure, + FAC::CovarianceStructure{T1}, p::AbstractODEFilterPrior{T}, dt, -) where {T} +) where {T,T1} + @assert T == T1 d, q = p.wiener_process_dimension, p.num_derivatives - P, PI = init_preconditioner(FAC, d, q, T) + P, PI = init_preconditioner(FAC) make_preconditioner!(P, dt, d, q) make_preconditioner_inv!(PI, dt, d, q) return P, PI diff --git a/src/projection.jl b/src/projection.jl index 03e92ba34..1a2c31ec8 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -13,27 +13,17 @@ function projection( end return Proj end -function projection( - ::DenseCovariance, - d::Integer, - q::Integer, - ::Type{elType}=typeof(1.0), -) where {elType} - projection(d, q, elType) +function projection(C::DenseCovariance{elType}) where {elType} + projection(C.d, C.q, elType) end -function projection( - ::IsometricKroneckerCovariance, - d::Integer, - q::Integer, - ::Type{elType}=typeof(1.0), -) where {elType} +function projection(C::IsometricKroneckerCovariance{elType}) where {elType} Proj(deriv) = begin - e_i = zeros(elType, q + 1, 1) - if deriv <= q + e_i = zeros(elType, C.q + 1, 1) + if deriv <= C.q e_i[deriv+1] = 1 end - return IsometricKroneckerProduct(d, e_i') + return IsometricKroneckerProduct(C.d, e_i') end return Proj end