diff --git a/src/ProbNumDiffEq.jl b/src/ProbNumDiffEq.jl
index 4a6cc6a0e..f6f176929 100644
--- a/src/ProbNumDiffEq.jl
+++ b/src/ProbNumDiffEq.jl
@@ -41,6 +41,7 @@ vecvec2mat(x) = reduce(hcat, x)'
 
 include("fast_linalg.jl")
 include("kronecker.jl")
+include("covariance_factorizations.jl")
 
 abstract type AbstractODEFilterCache <: OrdinaryDiffEq.OrdinaryDiffEqCache end
 
diff --git a/src/algorithms.jl b/src/algorithms.jl
index 96a043b04..efba620de 100644
--- a/src/algorithms.jl
+++ b/src/algorithms.jl
@@ -111,15 +111,6 @@ EK1(;
     initialization,
 )
 
-iskronecker(alg, f) = (
-    alg isa EK0
-    &&
-    !(alg.diffusionmodel isa DynamicMVDiffusion ||
-      alg.diffusionmodel isa FixedMVDiffusion)
-    && alg.prior isa IWP
-    && f.mass_matrix isa UniformScaling
-)
-
 """
     ExpEK(; L, order=3, kwargs...)
 
diff --git a/src/caches.jl b/src/caches.jl
index 8860c7954..36003b01b 100644
--- a/src/caches.jl
+++ b/src/caches.jl
@@ -91,7 +91,10 @@ function OrdinaryDiffEq.alg_cache(
     d = is_secondorder_ode ? length(u[1, :]) : length(u)
     D = d * (q + 1)
 
-    KRONECKER = iskronecker(alg, f)
+    FAC = get_covariance_factorization(alg)
+    if FAC isa KroneckerCovariance && !(f.mass_matrix isa UniformScaling)
+        error("The selected algorithm uses an efficient Kronecker-factorized implementation which is incompatible with the provided mass matrix. Try using the `EK1` instead.")
+    end
 
     uType = typeof(u)
     # uElType = eltype(u_vec)
@@ -99,11 +102,11 @@ function OrdinaryDiffEq.alg_cache(
     matType = Matrix{uElType}
 
     # Projections
-    Proj = projection(d, q, uElType)
+    Proj = projection(FAC, d, q, uElType)
     E0, E1, E2 = Proj(0), Proj(1), Proj(2)
     @assert f isa SciMLBase.AbstractODEFunction
     SolProj = if is_secondorder_ode
-        if KRONECKER
+        if E0 isa IKP
             IsoKroneckerProduct(d, [Proj(1).B; Proj(0).B])
         else
             SolProj = [Proj(1); Proj(0)]
@@ -111,10 +114,6 @@ function OrdinaryDiffEq.alg_cache(
     else
         Proj(0)
     end
-    if !KRONECKER
-        E0, E1, E2 = Matrix(E0), Matrix(E1), Matrix(E2)
-        SolProj = Matrix(SolProj)
-    end
 
     # Prior dynamics
     prior = if alg.prior isa IWP
@@ -129,11 +128,7 @@ function OrdinaryDiffEq.alg_cache(
     else
         error("Invalid prior $(alg.prior)")
     end
-    A, Q, Ah, Qh, P, PI = initialize_transition_matrices(prior, dt)
-    if !KRONECKER
-        P, PI = Diagonal(P), Diagonal(PI)
-        A, Q, Ah, Qh = Matrix(A), PSDMatrix(Matrix(Q.R)), Matrix(Ah), PSDMatrix(Matrix(Qh.R))
-    end
+    A, Q, Ah, Qh, P, PI = initialize_transition_matrices(FAC, prior, dt)
 
     # Measurement Model
     measurement_model = make_measurement_model(f)
@@ -141,11 +136,7 @@ function OrdinaryDiffEq.alg_cache(
     # Initial State
     initial_variance = ones(uElType, q + 1)
     μ0 = zeros(uElType, D)
-    Σ0 = PSDMatrix(if KRONECKER
-        IsoKroneckerProduct(d, diagm(sqrt.(initial_variance)))
-    else
-        kron(I(d), diagm(sqrt.(initial_variance)))
-    end)
+    Σ0 = PSDMatrix(to_factorized_matrix(FAC, IsoKroneckerProduct(d, diagm(sqrt.(initial_variance)))))
     x0 = Gaussian(μ0, Σ0)
 
     # Diffusion Model
@@ -155,17 +146,9 @@ function OrdinaryDiffEq.alg_cache(
 
     # Measurement model related things
     R = zeros(uElType, d, d)
-    H = if KRONECKER
-        IsoKroneckerProduct(d, zeros(uElType, 1, q + 1))
-    else
-        zeros(uElType, d, D)
-    end
+    H = factorized_zeros(FAC, uElType, d, D; d, q)
     v = zeros(uElType, d)
-    S = if KRONECKER
-        PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, q + 1)))
-    else
-        PSDMatrix(zeros(uElType, D, d))
-    end
+    S = PSDMatrix(factorized_zeros(FAC, uElType, D, d; d, q))
     measurement = Gaussian(v, S)
 
     # Caches
@@ -174,22 +157,11 @@ function OrdinaryDiffEq.alg_cache(
     pu_tmp = if !is_secondorder_ode # same dimensions as `measurement`
         copy(measurement)
     else # then `u` has 2d dimensions
-        Gaussian(
-            zeros(uElType, 2d),
-            PSDMatrix(
-                if KRONECKER
-                    IsoKroneckerProduct(d, zeros(uElType, q+1, 2))
-                else
-                    zeros(uElType, D, 2d)
-                end))
+        Gaussian(zeros(uElType, 2d), PSDMatrix(factorized_zeros(FAC, uElType, D, 2d; d, q)))
     end
     K = zeros(uElType, D, d)
     G = zeros(uElType, D, D)
-    Smat = if KRONECKER
-        IsoKroneckerProduct(d, zeros(uElType, 1, 1))
-    else
-        zeros(uElType, d, d)
-    end
+    Smat = factorized_zeros(FAC, uElType, d, d; d, q)
     C_dxd = zeros(uElType, d, d)
     C_dxD = zeros(uElType, d, D)
 
@@ -198,16 +170,11 @@ function OrdinaryDiffEq.alg_cache(
     C_2DxD = zeros(uElType, 2D, D)
     C_3DxD = zeros(uElType, 3D, D)
 
-    backward_kernel = if KRONECKER
-        AffineNormalKernel(
-            IsoKroneckerProduct(d, zeros(uElType, q+1, q+1)),
-            zeros(uElType, D),
-            PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, 2*(q+1), q+1)))
-        )
-    else
-        AffineNormalKernel(
-            zeros(uElType, D, D), zeros(uElType, D), PSDMatrix(zeros(uElType, 2D, D)))
-    end
+    backward_kernel = AffineNormalKernel(
+        factorized_zeros(FAC, uElType, D, D; d, q),
+        zeros(uElType, D),
+        PSDMatrix(factorized_zeros(FAC, uElType, 2D, D; d, q))
+    )
 
     u_pred = copy(u)
     u_filt = copy(u)
diff --git a/src/covariance_factorizations.jl b/src/covariance_factorizations.jl
new file mode 100644
index 000000000..415f84803
--- /dev/null
+++ b/src/covariance_factorizations.jl
@@ -0,0 +1,61 @@
+"""
+    CovarianceFactorization
+
+Tag type selecting the matrix representation that the solver caches are built with.
+"""
+abstract type CovarianceFactorization end
+
+"Tag: matrices are created as `IsoKroneckerProduct`s."
+struct KroneckerCovariance <: CovarianceFactorization end
+
+"Tag: matrices are created as plain dense arrays."
+struct DenseCovariance <: CovarianceFactorization end
+
+"""
+    get_covariance_factorization(alg)
+
+Return `KroneckerCovariance()` if `alg` is an `EK0` with an `IWP` prior and a diffusion
+model that is neither `DynamicMVDiffusion` nor `FixedMVDiffusion`; otherwise return
+`DenseCovariance()`.
+"""
+function get_covariance_factorization(alg)
+    if (
+        alg isa EK0 &&
+        !(
+            alg.diffusionmodel isa DynamicMVDiffusion ||
+            alg.diffusionmodel isa FixedMVDiffusion
+        ) &&
+        alg.prior isa IWP
+    )
+        return KroneckerCovariance()
+    else
+        return DenseCovariance()
+    end
+end
+
+"""
+    factorized_zeros(fac, elType, sizes...; d, q)
+
+Allocate a zero array of (full) size `sizes` in the representation selected by `fac`.
+
+For `KroneckerCovariance` every entry of `sizes` must be divisible by `d`; the result is
+`IsoKroneckerProduct(d, B)` with `B` of size `sizes .÷ d`. `q` is accepted for a uniform
+call signature but is not used.
+"""
+factorized_zeros(::KroneckerCovariance, elType, sizes...; d, q) = begin
+    for s in sizes
+        @assert s % d == 0
+    end
+    return IsoKroneckerProduct(d, zeros(elType, (s ÷ d for s in sizes)...))
+end
+
+factorized_zeros(::DenseCovariance, elType, sizes...; d, q) = zeros(elType, sizes...)
+
+"""
+    to_factorized_matrix(fac, M)
+
+Convert `M` to the representation selected by `fac`.
+"""
+to_factorized_matrix(::DenseCovariance, M::AbstractMatrix) = Matrix(M)
+to_factorized_matrix(::KroneckerCovariance, M::AbstractMatrix) = IsoKroneckerProduct(M) # probably errors
+to_factorized_matrix(::KroneckerCovariance, M::IKP) = M
diff --git a/src/kronecker.jl b/src/kronecker.jl
index 4042a7e8b..4d801fdd4 100644
--- a/src/kronecker.jl
+++ b/src/kronecker.jl
@@ -18,8 +18,23 @@ Base.:*(K::IKP, a::Number) = IsoKroneckerProduct(K.ldim, K.B * a)
 Base.:*(a::Number, K::IKP) = IsoKroneckerProduct(K.ldim, a * K.B)
 LinearAlgebra.adjoint(A::IKP) = IsoKroneckerProduct(A.ldim, A.B')
 
+"""
+    check_same_size(A::IKP, B::IKP)
+
+Throw a `DimensionMismatch` unless `A` and `B` share `ldim` and the size of their `B` factor.
+"""
+function check_same_size(A::IKP, B::IKP)
+    if A.ldim != B.ldim || size(A.B) != size(B.B)
+        # Report the factor sizes: `size(A)` is already multiplied by `ldim`.
+        Ad, An, Am = A.ldim, size(A.B, 1), size(A.B, 2)
+        Bd, Bn, Bm = B.ldim, size(B.B, 1), size(B.B, 2)
+        throw(
+            DimensionMismatch("A has size ($Ad⋅$An,$Ad⋅$Am), B has size ($Bd⋅$Bn,$Bd⋅$Bm)"),
+        )
+    end
+end
 Base.:+(A::IKP, B::IKP) = begin
-    @assert A.ldim == B.ldim
+    check_same_size(A, B)
     return IsoKroneckerProduct(A.ldim, A.B + B.B)
 end
 Base.:+(U::UniformScaling, K::IKP) = IsoKroneckerProduct(K.ldim, U + K.B)
@@ -68,7 +83,7 @@ Base.size(K::IKP) = (K.ldim * size(K.B, 1), K.ldim * size(K.B, 2))
 
 # conversion
 Base.convert(::Type{T}, K::IKP) where {T<:IKP} = K isa T ? K : T(K)
-function IKP{T,TB}(K::IKP) where {T,TA,TB}
+function IKP{T,TB}(K::IKP) where {T,TB}
     IKP(K.ldim, convert(TB, K.B))
 end
 
@@ -80,6 +95,7 @@ reshape_no_alloc(a, dims::Tuple) =
     invoke(Base._reshape, Tuple{AbstractArray,typeof(dims)}, a, dims)
 # reshape_no_alloc(a::AbstractArray, dims::Tuple) = reshape(a, dims)
 reshape_no_alloc(a, dims...) = reshape_no_alloc(a, Tuple(dims))
+reshape_no_alloc(a::Missing, dims::Tuple) = missing
 
 function mul_vectrick!(x::AbstractVecOrMat, A::IsoKroneckerProduct, v::AbstractVecOrMat)
     N = A.B
diff --git a/src/priors/common.jl b/src/priors/common.jl
index 46f4658a3..0c3dd8bc1 100644
--- a/src/priors/common.jl
+++ b/src/priors/common.jl
@@ -31,7 +31,7 @@ See also: [`make_transition_matrices`](@ref).
 
 [1] N. Krämer, P. Hennig: **Stable Implementation of Probabilistic ODE Solvers** (2020)
 """
-function initialize_transition_matrices(p::AbstractODEFilterPrior{T}, dt) where {T}
+function initialize_transition_matrices(::DenseCovariance, p::AbstractODEFilterPrior{T}, dt) where {T}
     d, q = p.wiener_process_dimension, p.num_derivatives
     D = d * (q + 1)
     Ah, Qh = zeros(T, D, D), PSDMatrix(zeros(T, D, D))
@@ -40,6 +40,9 @@ function initialize_transition_matrices(p::AbstractODEFilterPrior{T}, dt) where
     Q = copy(Qh)
     return A, Q, Ah, Qh, P, PI
 end
+function initialize_transition_matrices(fac::CovarianceFactorization, p::AbstractODEFilterPrior, dt)
+    error("The chosen prior can not be implemented with a $fac factorization")
+end
 
 """
     make_transition_matrices!(cache, prior::AbstractODEFilterPrior, dt)
diff --git a/src/priors/iwp.jl b/src/priors/iwp.jl
index 22be1908a..d28a87a6d 100644
--- a/src/priors/iwp.jl
+++ b/src/priors/iwp.jl
@@ -105,13 +105,19 @@ function discretize(p::IWP, dt::Real)
     return A, Q
 end
 
-function initialize_transition_matrices(p::IWP{T}, dt) where {T}
+function initialize_transition_matrices(::KroneckerCovariance, p::IWP{T}, dt) where {T}
     A, Q = preconditioned_discretize(p)
     P, PI = initialize_preconditioner(p, dt)
     Ah = PI * A * P
     Qh = PSDMatrix(Q.R * PI)
     return A, Q, Ah, Qh, P, PI
 end
+function initialize_transition_matrices(::DenseCovariance, p::IWP{T}, dt) where {T}
+    A, Q, Ah, Qh, P, PI = initialize_transition_matrices(KroneckerCovariance(), p, dt)
+    P, PI = Diagonal(P), Diagonal(PI)
+    A, Q, Ah, Qh = Matrix(A), PSDMatrix(Matrix(Q.R)), Matrix(Ah), PSDMatrix(Matrix(Qh.R))
+    return A, Q, Ah, Qh, P, PI
+end
 
 function make_transition_matrices!(cache, prior::IWP, dt)
     @unpack A, Q, Ah, Qh, P, PI = cache
diff --git a/src/projection.jl b/src/projection.jl
index 1c316d722..f9935a812 100644
--- a/src/projection.jl
+++ b/src/projection.jl
@@ -1,10 +1,36 @@
-function projection(d::Integer, q::Integer, ::Type{elType}=typeof(1.0)) where {elType}
+function projection(
+    ::DenseCovariance,
+    d::Integer,
+    q::Integer,
+    ::Type{elType}=typeof(1.0),
+) where {elType}
+    D = d * (q + 1)
     Proj(deriv) = begin
-        e_i = zeros(elType, q+1, 1)
+        P = zeros(elType, d, D)
+        # Derivatives above order `q` are not part of the state; leave `P` zero for them,
+        # matching the Kronecker-factorized projection below.
+        if deriv <= q
+            @simd ivdep for i in deriv*d+1:D+1:d*D
+                @inbounds P[i] = 1
+            end
+        end
+        return P
+    end
+    return Proj
+end
+
+function projection(
+    ::KroneckerCovariance,
+    d::Integer,
+    q::Integer,
+    ::Type{elType}=typeof(1.0),
+) where {elType}
+    Proj(deriv) = begin
+        e_i = zeros(elType, q + 1, 1)
         if deriv <= q
             e_i[deriv+1] = 1
         end
-        IsoKroneckerProduct(d, e_i')
+        return IsoKroneckerProduct(d, e_i')
     end
     return Proj
 end
diff --git a/src/solution.jl b/src/solution.jl
index ea11fca32..bf27a6053 100644
--- a/src/solution.jl
+++ b/src/solution.jl
@@ -85,6 +85,7 @@ function DiffEqBase.build_solution(
         true,
         Val(isinplace(prob)),
     )
+    q = cache.q
 
     T = eltype(eltype(u))
     N = length((size(prob.u0)..., length(u)))
@@ -93,18 +94,10 @@ function DiffEqBase.build_solution(
     uElType = eltype(prob.u0)
     D = d
 
-    KRONECKER = iskronecker(alg, prob.f)
+    FAC = get_covariance_factorization(alg)
 
-    pu_cov = if KRONECKER
-        PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, D ÷ d + 1)))
-    else
-        PSDMatrix(zeros(uElType, D, d))
-    end
-    x_cov = if KRONECKER
-        PSDMatrix(IsoKroneckerProduct(d, zeros(uElType, D ÷ d + 1, D ÷ d + 1)))
-    else
-        PSDMatrix(zeros(uElType, D, D))
-    end
+    pu_cov = PSDMatrix(factorized_zeros(FAC, uElType, D, d; d, q))
+    x_cov = PSDMatrix(factorized_zeros(FAC, uElType, D, D; d, q))
     pu = StructArray{Gaussian{Vector{uElType},typeof(pu_cov)}}(undef, 0)
     x_filt = StructArray{Gaussian{Vector{uElType},typeof(x_cov)}}(undef, 0)
     x_smooth = copy(x_filt)