From 5519c314223ee05ec8f90d3fae5a6ea9df2d0747 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Tue, 13 Feb 2024 11:41:24 +0100 Subject: [PATCH] Implement a first version of the DiagonalEK1 --- src/ProbNumDiffEq.jl | 2 +- src/alg_utils.jl | 14 ++++++------ src/algorithms.jl | 43 ++++++++++++++++++++++++++++++++++++- src/caches.jl | 9 ++++---- src/covariance_structure.jl | 2 ++ src/derivative_utils.jl | 18 +++++++++++++++- 6 files changed, 75 insertions(+), 13 deletions(-) diff --git a/src/ProbNumDiffEq.jl b/src/ProbNumDiffEq.jl index 802c653b7..d4c73da69 100644 --- a/src/ProbNumDiffEq.jl +++ b/src/ProbNumDiffEq.jl @@ -74,7 +74,7 @@ include("initialization/common.jl") export TaylorModeInit, ClassicSolverInit, SimpleInit, ForwardDiffInit include("algorithms.jl") -export EK0, EK1 +export EK0, EK1, DiagonalEK1 export ExpEK, RosenbrockExpEK include("alg_utils.jl") diff --git a/src/alg_utils.jl b/src/alg_utils.jl index a3c84736f..a845a0ad4 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -4,19 +4,21 @@ ############################################################################################ OrdinaryDiffEq._alg_autodiff(::AbstractEK) = Val{true}() -OrdinaryDiffEq._alg_autodiff(::EK1{CS,AD}) where {CS,AD} = Val{AD}() -OrdinaryDiffEq.alg_difftype(::EK1{CS,AD,DiffType}) where {CS,AD,DiffType} = DiffType OrdinaryDiffEq.standardtag(::AbstractEK) = false -OrdinaryDiffEq.standardtag(::EK1{CS,AD,DiffType,ST}) where {CS,AD,DiffType,ST} = ST OrdinaryDiffEq.concrete_jac(::AbstractEK) = nothing -OrdinaryDiffEq.concrete_jac(::EK1{CS,AD,DiffType,ST,CJ}) where {CS,AD,DiffType,ST,CJ} = CJ @inline DiffEqBase.get_tmp_cache(integ, alg::AbstractEK, cache::AbstractODEFilterCache) = (cache.tmp, cache.atmp) -OrdinaryDiffEq.get_chunksize(::EK1{CS}) where {CS} = Val(CS) OrdinaryDiffEq.isfsal(::AbstractEK) = false -OrdinaryDiffEq.isimplicit(::EK1) = true +for ALG in [:EK1, :DiagonalEK1] + @eval OrdinaryDiffEq._alg_autodiff(::$ALG{CS,AD}) where {CS,AD} = Val{AD}() + @eval OrdinaryDiffEq.alg_difftype(::$ALG{CS,AD,DiffType}) where {CS,AD,DiffType} = DiffType + @eval OrdinaryDiffEq.standardtag(::$ALG{CS,AD,DiffType,ST}) where {CS,AD,DiffType,ST} = ST + @eval OrdinaryDiffEq.concrete_jac(::$ALG{CS,AD,DiffType,ST,CJ}) where {CS,AD,DiffType,ST,CJ} = CJ + @eval OrdinaryDiffEq.get_chunksize(::$ALG{CS}) where {CS} = Val(CS) + @eval OrdinaryDiffEq.isimplicit(::$ALG) = true +end ############################################ # Step size control diff --git a/src/algorithms.jl b/src/algorithms.jl index 9982f0f5d..c084fcd42 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -157,6 +157,47 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK end end +struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK + prior::PT + diffusionmodel::DT + smooth::Bool + initialization::IT + pn_observation_noise::RT + DiagonalEK1(; + order=3, + prior::PT=IWP(order), + diffusionmodel::DT=DynamicDiffusion(), + smooth=true, + initialization::IT=TaylorModeInit(num_derivatives(prior)), + chunk_size=Val{0}(), + autodiff=Val{true}(), + diff_type=Val{:forward}, + standardtag=Val{true}(), + concrete_jac=nothing, + pn_observation_noise::RT=nothing, + ) where {PT,DT,IT,RT} = begin + ekargcheck(DiagonalEK1; diffusionmodel, pn_observation_noise) + new{ + _unwrap_val(chunk_size), + _unwrap_val(autodiff), + diff_type, + _unwrap_val(standardtag), + _unwrap_val(concrete_jac), + PT, + DT, + IT, + RT, + }( + prior, + diffusionmodel, + smooth, + initialization, + pn_observation_noise, + ) + end +end + + """ ExpEK(; L, order=3, kwargs...) @@ -236,7 +277,7 @@ function DiffEqBase.remake(thing::EK1{CS,AD,DT,ST,CJ}; kwargs...) where {CS,AD,D ) end -function DiffEqBase.prepare_alg(alg::EK1{0}, u0::AbstractArray{T}, p, prob) where {T} +function DiffEqBase.prepare_alg(alg::Union{EK1{0},DiagonalEK1{0}}, u0::AbstractArray{T}, p, prob) where {T} # See OrdinaryDiffEq.jl: ./src/alg_utils.jl (where this is copied from). # In the future we might want to make EK1 an OrdinaryDiffEqAdaptiveImplicitAlgorithm and # use the prepare_alg from OrdinaryDiffEq; but right now, we do not use `linsolve` which diff --git a/src/caches.jl b/src/caches.jl index ab3e41917..93bd4920e 100644 --- a/src/caches.jl +++ b/src/caches.jl @@ -5,7 +5,7 @@ mutable struct EKCache{ RType,CFacType,ProjType,SolProjType,PType,PIType,EType,uType,duType,xType,PriorType, AType,QType, FType,LType,FHGMethodType,FHGCacheType, - HType,vecType,matType,bkType,diffusionType,diffModelType,measModType,measType, + HType,vecType,dduType,matType,bkType,diffusionType,diffModelType,measModType,measType, puType,llType,dtType,rateType,UF,JC,uNoUnitsType, } <: AbstractODEFilterCache # Constants @@ -49,7 +49,7 @@ mutable struct EKCache{ pu_tmp::puType H::HType du::duType - ddu::matType + ddu::dduType K1::matType G1::matType Smat::HType @@ -178,7 +178,8 @@ function OrdinaryDiffEq.alg_cache( # Caches du = is_secondorder_ode ? similar(u.x[2]) : similar(u) - ddu = factorized_similar(FAC, length(u), length(u)) + # ddu = factorized_similar(FAC, length(u), length(u)) + ddu = similar(u, length(u), length(u)) _d = is_secondorder_ode ? 2d : d pu_tmp = Gaussian( similar(Array{uElType}, _d), @@ -242,7 +243,7 @@ function OrdinaryDiffEq.alg_cache( typeof(R),typeof(FAC),typeof(Proj),typeof(SolProj),typeof(P),typeof(PI),typeof(E0), uType,typeof(du),typeof(x0),typeof(prior),typeof(A),typeof(Q), typeof(F),typeof(L),typeof(FHG_method),typeof(FHG_cache), - typeof(H),typeof(C_d),matType, + typeof(H),typeof(C_d),typeof(ddu),matType, typeof(backward_kernel),typeof(initdiff), typeof(diffmodel),typeof(measurement_model),typeof(measurement),typeof(pu_tmp), uEltypeNoUnits,typeof(dt),typeof(du1),typeof(uf),typeof(jac_config),typeof(atmp), diff --git a/src/covariance_structure.jl b/src/covariance_structure.jl index 39d46126e..46955c4d5 100644 --- a/src/covariance_structure.jl +++ b/src/covariance_structure.jl @@ -22,6 +22,8 @@ function get_covariance_structure(alg; elType, d, q) alg.prior isa IWP ) return IsometricKroneckerCovariance{elType}(d, q) + elseif alg isa DiagonalEK1 + return BlockDiagonalCovariance{elType}(d, q) else return DenseCovariance{elType}(d, q) end diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 9b84fb1c8..bd08e409c 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -8,7 +8,23 @@ function calc_H!(H, integ, cache) calc_H_EK0!(H, integ, cache) # @assert integ.u == @view x_pred.μ[1:(q+1):end] OrdinaryDiffEq.calc_J!(ddu, integ, cache, true) - ProbNumDiffEq._matmul!(H, view(ddu, 1:d, :), cache.SolProj, -1.0, 1.0) + _matmul!(H, view(ddu, 1:d, :), cache.SolProj, -1.0, 1.0) + elseif integ.alg isa DiagonalEK1 + calc_H_EK0!(H, integ, cache) + # @assert integ.u == @view x_pred.μ[1:(q+1):end] + # ddu_full = Matrix(ddu) + # @info "ddu" ddu_full + # error() + OrdinaryDiffEq.calc_J!(ddu, integ, cache, true) + + @unpack C_dxd = cache + @simd ivdep for i in eachindex(blocks(C_dxd)) + @assert length(C_dxd.blocks[i]) == 1 + C_dxd.blocks[i][1] = ddu[i, i] + end + _matmul!(H, C_dxd, cache.SolProj, -1.0, 1.0) + else + error("Unknown algorithm") end return nothing end