Skip to content

Commit

Permalink
Implement a first version of the DiagonalEK1
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 13, 2024
1 parent d0a3eb0 commit 5519c31
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ include("initialization/common.jl")
export TaylorModeInit, ClassicSolverInit, SimpleInit, ForwardDiffInit

include("algorithms.jl")
export EK0, EK1
export EK0, EK1, DiagonalEK1
export ExpEK, RosenbrockExpEK

include("alg_utils.jl")
Expand Down
14 changes: 8 additions & 6 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
############################################################################################

OrdinaryDiffEq._alg_autodiff(::AbstractEK) = Val{true}()
OrdinaryDiffEq._alg_autodiff(::EK1{CS,AD}) where {CS,AD} = Val{AD}()
OrdinaryDiffEq.alg_difftype(::EK1{CS,AD,DiffType}) where {CS,AD,DiffType} = DiffType
OrdinaryDiffEq.standardtag(::AbstractEK) = false
OrdinaryDiffEq.standardtag(::EK1{CS,AD,DiffType,ST}) where {CS,AD,DiffType,ST} = ST
OrdinaryDiffEq.concrete_jac(::AbstractEK) = nothing
OrdinaryDiffEq.concrete_jac(::EK1{CS,AD,DiffType,ST,CJ}) where {CS,AD,DiffType,ST,CJ} = CJ

@inline DiffEqBase.get_tmp_cache(integ, alg::AbstractEK, cache::AbstractODEFilterCache) =
(cache.tmp, cache.atmp)
OrdinaryDiffEq.get_chunksize(::EK1{CS}) where {CS} = Val(CS)
OrdinaryDiffEq.isfsal(::AbstractEK) = false

OrdinaryDiffEq.isimplicit(::EK1) = true
for ALG in [:EK1, :DiagonalEK1]
@eval OrdinaryDiffEq._alg_autodiff(::$ALG{CS,AD}) where {CS,AD} = Val{AD}()
@eval OrdinaryDiffEq.alg_difftype(::$ALG{CS,AD,DiffType}) where {CS,AD,DiffType} = DiffType
@eval OrdinaryDiffEq.standardtag(::$ALG{CS,AD,DiffType,ST}) where {CS,AD,DiffType,ST} = ST
@eval OrdinaryDiffEq.concrete_jac(::$ALG{CS,AD,DiffType,ST,CJ}) where {CS,AD,DiffType,ST,CJ} = CJ
@eval OrdinaryDiffEq.get_chunksize(::$ALG{CS}) where {CS} = Val(CS)
@eval OrdinaryDiffEq.isimplicit(::$ALG) = true
end

############################################
# Step size control
Expand Down
43 changes: 42 additions & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,47 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK
end
end

struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
DiagonalEK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} = begin
ekargcheck(DiagonalEK1; diffusionmodel, pn_observation_noise)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
PT,
DT,
IT,
RT,
}(
prior,
diffusionmodel,
smooth,
initialization,
pn_observation_noise,
)
end
end


"""
ExpEK(; L, order=3, kwargs...)
Expand Down Expand Up @@ -236,7 +277,7 @@ function DiffEqBase.remake(thing::EK1{CS,AD,DT,ST,CJ}; kwargs...) where {CS,AD,D
)
end

function DiffEqBase.prepare_alg(alg::EK1{0}, u0::AbstractArray{T}, p, prob) where {T}
function DiffEqBase.prepare_alg(alg::Union{EK1{0},DiagonalEK1{0}}, u0::AbstractArray{T}, p, prob) where {T}
# See OrdinaryDiffEq.jl: ./src/alg_utils.jl (where this is copied from).
# In the future we might want to make EK1 an OrdinaryDiffEqAdaptiveImplicitAlgorithm and
# use the prepare_alg from OrdinaryDiffEq; but right now, we do not use `linsolve` which
Expand Down
9 changes: 5 additions & 4 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mutable struct EKCache{
RType,CFacType,ProjType,SolProjType,PType,PIType,EType,uType,duType,xType,PriorType,
AType,QType,
FType,LType,FHGMethodType,FHGCacheType,
HType,vecType,matType,bkType,diffusionType,diffModelType,measModType,measType,
HType,vecType,dduType,matType,bkType,diffusionType,diffModelType,measModType,measType,
puType,llType,dtType,rateType,UF,JC,uNoUnitsType,
} <: AbstractODEFilterCache
# Constants
Expand Down Expand Up @@ -49,7 +49,7 @@ mutable struct EKCache{
pu_tmp::puType
H::HType
du::duType
ddu::matType
ddu::dduType
K1::matType
G1::matType
Smat::HType
Expand Down Expand Up @@ -178,7 +178,8 @@ function OrdinaryDiffEq.alg_cache(

# Caches
du = is_secondorder_ode ? similar(u.x[2]) : similar(u)
ddu = factorized_similar(FAC, length(u), length(u))
# ddu = factorized_similar(FAC, length(u), length(u))
ddu = similar(u, length(u), length(u))
_d = is_secondorder_ode ? 2d : d
pu_tmp = Gaussian(
similar(Array{uElType}, _d),
Expand Down Expand Up @@ -242,7 +243,7 @@ function OrdinaryDiffEq.alg_cache(
typeof(R),typeof(FAC),typeof(Proj),typeof(SolProj),typeof(P),typeof(PI),typeof(E0),
uType,typeof(du),typeof(x0),typeof(prior),typeof(A),typeof(Q),
typeof(F),typeof(L),typeof(FHG_method),typeof(FHG_cache),
typeof(H),typeof(C_d),matType,
typeof(H),typeof(C_d),typeof(ddu),matType,
typeof(backward_kernel),typeof(initdiff),
typeof(diffmodel),typeof(measurement_model),typeof(measurement),typeof(pu_tmp),
uEltypeNoUnits,typeof(dt),typeof(du1),typeof(uf),typeof(jac_config),typeof(atmp),
Expand Down
2 changes: 2 additions & 0 deletions src/covariance_structure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ function get_covariance_structure(alg; elType, d, q)
alg.prior isa IWP
)
return IsometricKroneckerCovariance{elType}(d, q)
elseif alg isa DiagonalEK1
return BlockDiagonalCovariance{elType}(d, q)
else
return DenseCovariance{elType}(d, q)
end
Expand Down
18 changes: 17 additions & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,23 @@ function calc_H!(H, integ, cache)
calc_H_EK0!(H, integ, cache)
# @assert integ.u == @view x_pred.μ[1:(q+1):end]
OrdinaryDiffEq.calc_J!(ddu, integ, cache, true)
ProbNumDiffEq._matmul!(H, view(ddu, 1:d, :), cache.SolProj, -1.0, 1.0)
_matmul!(H, view(ddu, 1:d, :), cache.SolProj, -1.0, 1.0)
elseif integ.alg isa DiagonalEK1
calc_H_EK0!(H, integ, cache)
# @assert integ.u == @view x_pred.μ[1:(q+1):end]
# ddu_full = Matrix(ddu)
# @info "ddu" ddu_full
# error()
OrdinaryDiffEq.calc_J!(ddu, integ, cache, true)

@unpack C_dxd = cache
@simd ivdep for i in eachindex(blocks(C_dxd))
@assert length(C_dxd.blocks[i]) == 1
C_dxd.blocks[i][1] = ddu[i, i]
end
_matmul!(H, C_dxd, cache.SolProj, -1.0, 1.0)
else
error("Unknown algorithm")
end
return nothing
end
Expand Down

0 comments on commit 5519c31

Please sign in to comment.