diff --git a/Project.toml b/Project.toml index 69bd3e23b..1c45dde83 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ProbNumDiffEq" uuid = "bf3e78b0-7d74-48a5-b855-9609533b56a5" authors = ["Nathanael Bosch"] -version = "0.3.2" +version = "0.4.0" [deps] DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" diff --git a/src/alg_utils.jl b/src/alg_utils.jl index d036553c0..57e57bb16 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -4,6 +4,9 @@ ############################################################################################ OrdinaryDiffEq.alg_autodiff(alg::AbstractEK) = true +OrdinaryDiffEq.alg_autodiff(alg::EK1{CS,AD}) where {CS,AD} = AD +OrdinaryDiffEq.alg_difftype(alg::EK1{CS,AD,DiffType}) where {CS,AD,DiffType} = DiffType +@inline DiffEqBase.get_tmp_cache(integ, alg::EK1, cache) = (cache.tmp, cache.atmp) OrdinaryDiffEq.get_chunksize(alg::AbstractEK) = Val(0) OrdinaryDiffEq.isfsal(alg::AbstractEK) = false diff --git a/src/algorithms.jl b/src/algorithms.jl index 05f4777be..af379b4d7 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -43,13 +43,29 @@ See also: [`EK0`](@ref) - N. Bosch, P. Hennig, F. Tronarp: **Calibrated Adaptive Probabilistic ODE Solvers** (2021) - F. Tronarp, H. Kersting, S. Särkkä, and P. Hennig: **Probabilistic Solutions To Ordinary Differential Equations As Non-Linear Bayesian Filtering: A New Perspective** (2019) """ -Base.@kwdef struct EK1{IT} <: AbstractEK - prior::Symbol = :ibm - order::Int = 3 - diffusionmodel::Symbol = :dynamic - smooth::Bool = true - initialization::IT = TaylorModeInit() +struct EK1{CS,AD,DiffType,IT} <: AbstractEK + prior::Symbol + order::Int + diffusionmodel::Symbol + smooth::Bool + initialization::IT end +EK1(; + prior=:ibm, + order=3, + diffusionmodel=:dynamic, + smooth=true, + initialization=TaylorModeInit(), + chunk_size=0, + autodiff=true, + diff_type=Val{:forward}, +) = EK1{chunk_size,autodiff,diff_type,typeof(initialization)}( + prior, + order, + diffusionmodel, + smooth, + initialization, +) Base.@kwdef struct EK1FDB{IT} <: AbstractEK prior::Symbol = :ibm diff --git a/src/caches.jl b/src/caches.jl index 64977e91c..31410defe 100644 --- a/src/caches.jl +++ b/src/caches.jl @@ -21,6 +21,10 @@ mutable struct GaussianODEFilterCache{ puType, llType, CType, + rateType, + UF, + JC, + uNoUnitsType, } <: ODEFiltersCache # Constants d::Int # Dimension of the problem @@ -44,6 +48,7 @@ mutable struct GaussianODEFilterCache{ u_pred::uType u_filt::uType tmp::uType + atmp::uNoUnitsType x::xType x_pred::xType x_filt::xType @@ -66,6 +71,9 @@ mutable struct GaussianODEFilterCache{ log_likelihood::llType C1::CType C2::CType + du1::rateType + uf::UF + jac_config::JC end function OrdinaryDiffEq.alg_cache( @@ -83,8 +91,8 @@ function OrdinaryDiffEq.alg_cache( reltol, p, calck, - IIP, -) + ::Val{IIP}, +) where {IIP} initialize_derivatives = true if u isa Number @@ -174,6 +182,16 @@ function OrdinaryDiffEq.alg_cache( K2 = similar(K) G2 = similar(G) err_tmp = similar(du) + + # Things for calc_J + uf = + IIP == true ? OrdinaryDiffEq.UJacobianWrapper(f, t, p) : + OrdinaryDiffEq.UDerivativeWrapper(f, t, p) + du1 = similar(rate_prototype) + dw1 = zero(u) + jac_config = OrdinaryDiffEq.build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1) + atmp = similar(u, uEltypeNoUnits) + return GaussianODEFilterCache{ typeof(R), typeof(Proj), @@ -193,6 +211,10 @@ function OrdinaryDiffEq.alg_cache( typeof(pu_tmp), uEltypeNoUnits, typeof(C1), + typeof(du1), + typeof(uf), + typeof(jac_config), + typeof(atmp), }( # Constants d, @@ -215,6 +237,7 @@ function OrdinaryDiffEq.alg_cache( u_pred, u_filt, tmp, + atmp, x0, x_pred, x_filt, @@ -237,5 +260,8 @@ function OrdinaryDiffEq.alg_cache( zero(uEltypeNoUnits), C1, C2, + du1, + uf, + jac_config, ) end diff --git a/src/perform_step.jl b/src/perform_step.jl index 921dacd2d..b83082591 100644 --- a/src/perform_step.jl +++ b/src/perform_step.jl @@ -174,12 +174,17 @@ function evaluate_ode!( # Jacobian is computed either with the given jac, or ForwardDiff if !isnothing(f.jac) _eval_f_jac!(ddu, u_lin, p, t, f) - elseif isinplace(f) - ForwardDiff.jacobian!(ddu, (du, u) -> f(du, u, p, t), du, u_lin) - integ.destats.nf += 1 else - ddu .= ForwardDiff.jacobian(u -> f(u, p, t), u_lin) - integ.destats.nf += 1 + !isnothing(f.jac) + @unpack du1, uf, jac_config = integ.cache + uf.f = OrdinaryDiffEq.nlsolve_f(f, alg) + uf.t = t + uf.p = p + if isinplace(f) + OrdinaryDiffEq.jacobian!(ddu, uf, u_lin, du1, integ, jac_config) + else + ddu .= OrdinaryDiffEq.jacobian(uf, u_lin, integ) + end end integ.destats.njacs += 1 diff --git a/test/specific_problems.jl b/test/specific_problems.jl index 61e171dc9..efd325cd8 100644 --- a/test/specific_problems.jl +++ b/test/specific_problems.jl @@ -92,7 +92,7 @@ end end @testset "OOP problem" begin - f(u, p, t) = ([p[1] * u[1] .* (1 .- u[1])]) + f(u, p, t) = p .* u .* (1 .- u) prob = ODEProblem(f, [1e-1], (0.0, 5), [3.0]) @testset "without jacobian" begin # first without defined jac @@ -211,3 +211,19 @@ end sol2 = solve(prob, RadauIIA5()) @test sol1[end] ≈ sol2[end] rtol = 1e-5 end + +@testset "EK1 Jacobian computation" begin + prob = prob_ode_fitzhughnagumo + @assert isnothing(prob.f.jac) + + # make sure that the kwarg works + sol1 = solve(prob, EK1()) + sol2 = solve(prob, EK1(autodiff=false)) + @test sol2 isa ProbNumDiffEq.ProbODESolution + + # check that forwarddiff leads to a smaller nf than finite diff + @test sol1.destats.nf < sol2.destats.nf + + # use the EK1 on a non-forwarddiffable function + # TODO +end