Skip to content

Commit

Permalink
Building on OrdinaryDiffEq to compute Jacobians (#89)
Browse files Browse the repository at this point in the history
* Building on OrdinaryDiffEq to compute Jacobians

efficiently, and with a switch between forward and finite diff!

* Fix IIP / OOP related problems

* Add a Jacobian computation-related test

* JuliaFormatter.jl

* Bump version number
  • Loading branch information
nathanaelbosch authored Nov 4, 2021
1 parent b9ad0ad commit dfd46c8
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ProbNumDiffEq"
uuid = "bf3e78b0-7d74-48a5-b855-9609533b56a5"
authors = ["Nathanael Bosch"]
version = "0.3.2"
version = "0.4.0"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
3 changes: 3 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
############################################################################################

OrdinaryDiffEq.alg_autodiff(alg::AbstractEK) = true
OrdinaryDiffEq.alg_autodiff(alg::EK1{CS,AD}) where {CS,AD} = AD
OrdinaryDiffEq.alg_difftype(alg::EK1{CS,AD,DiffType}) where {CS,AD,DiffType} = DiffType
@inline DiffEqBase.get_tmp_cache(integ, alg::EK1, cache) = (cache.tmp, cache.atmp)
OrdinaryDiffEq.get_chunksize(alg::AbstractEK) = Val(0)
OrdinaryDiffEq.isfsal(alg::AbstractEK) = false

Expand Down
28 changes: 22 additions & 6 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,29 @@ See also: [`EK0`](@ref)
- N. Bosch, P. Hennig, F. Tronarp: **Calibrated Adaptive Probabilistic ODE Solvers** (2021)
- F. Tronarp, H. Kersting, S. Särkkä, and P. Hennig: **Probabilistic Solutions To Ordinary Differential Equations As Non-Linear Bayesian Filtering: A New Perspective** (2019)
"""
Base.@kwdef struct EK1{IT} <: AbstractEK
prior::Symbol = :ibm
order::Int = 3
diffusionmodel::Symbol = :dynamic
smooth::Bool = true
initialization::IT = TaylorModeInit()
struct EK1{CS,AD,DiffType,IT} <: AbstractEK
prior::Symbol
order::Int
diffusionmodel::Symbol
smooth::Bool
initialization::IT
end
EK1(;
prior=:ibm,
order=3,
diffusionmodel=:dynamic,
smooth=true,
initialization=TaylorModeInit(),
chunk_size=0,
autodiff=true,
diff_type=Val{:forward},
) = EK1{chunk_size,autodiff,diff_type,typeof(initialization)}(
prior,
order,
diffusionmodel,
smooth,
initialization,
)

Base.@kwdef struct EK1FDB{IT} <: AbstractEK
prior::Symbol = :ibm
Expand Down
30 changes: 28 additions & 2 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ mutable struct GaussianODEFilterCache{
puType,
llType,
CType,
rateType,
UF,
JC,
uNoUnitsType,
} <: ODEFiltersCache
# Constants
d::Int # Dimension of the problem
Expand All @@ -44,6 +48,7 @@ mutable struct GaussianODEFilterCache{
u_pred::uType
u_filt::uType
tmp::uType
atmp::uNoUnitsType
x::xType
x_pred::xType
x_filt::xType
Expand All @@ -66,6 +71,9 @@ mutable struct GaussianODEFilterCache{
log_likelihood::llType
C1::CType
C2::CType
du1::rateType
uf::UF
jac_config::JC
end

function OrdinaryDiffEq.alg_cache(
Expand All @@ -83,8 +91,8 @@ function OrdinaryDiffEq.alg_cache(
reltol,
p,
calck,
IIP,
)
::Val{IIP},
) where {IIP}
initialize_derivatives = true

if u isa Number
Expand Down Expand Up @@ -174,6 +182,16 @@ function OrdinaryDiffEq.alg_cache(
K2 = similar(K)
G2 = similar(G)
err_tmp = similar(du)

# Things for calc_J
uf =
IIP == true ? OrdinaryDiffEq.UJacobianWrapper(f, t, p) :
OrdinaryDiffEq.UDerivativeWrapper(f, t, p)
du1 = similar(rate_prototype)
dw1 = zero(u)
jac_config = OrdinaryDiffEq.build_jac_config(alg, f, uf, du1, uprev, u, tmp, dw1)
atmp = similar(u, uEltypeNoUnits)

return GaussianODEFilterCache{
typeof(R),
typeof(Proj),
Expand All @@ -193,6 +211,10 @@ function OrdinaryDiffEq.alg_cache(
typeof(pu_tmp),
uEltypeNoUnits,
typeof(C1),
typeof(du1),
typeof(uf),
typeof(jac_config),
typeof(atmp),
}(
# Constants
d,
Expand All @@ -215,6 +237,7 @@ function OrdinaryDiffEq.alg_cache(
u_pred,
u_filt,
tmp,
atmp,
x0,
x_pred,
x_filt,
Expand All @@ -237,5 +260,8 @@ function OrdinaryDiffEq.alg_cache(
zero(uEltypeNoUnits),
C1,
C2,
du1,
uf,
jac_config,
)
end
15 changes: 10 additions & 5 deletions src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,12 +174,17 @@ function evaluate_ode!(
# Jacobian is computed either with the given jac, or ForwardDiff
if !isnothing(f.jac)
_eval_f_jac!(ddu, u_lin, p, t, f)
elseif isinplace(f)
ForwardDiff.jacobian!(ddu, (du, u) -> f(du, u, p, t), du, u_lin)
integ.destats.nf += 1
else
ddu .= ForwardDiff.jacobian(u -> f(u, p, t), u_lin)
integ.destats.nf += 1
!isnothing(f.jac)
@unpack du1, uf, jac_config = integ.cache
uf.f = OrdinaryDiffEq.nlsolve_f(f, alg)
uf.t = t
uf.p = p
if isinplace(f)
OrdinaryDiffEq.jacobian!(ddu, uf, u_lin, du1, integ, jac_config)
else
ddu .= OrdinaryDiffEq.jacobian(uf, u_lin, integ)
end
end
integ.destats.njacs += 1

Expand Down
18 changes: 17 additions & 1 deletion test/specific_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ end
end

@testset "OOP problem" begin
f(u, p, t) = ([p[1] * u[1] .* (1 .- u[1])])
f(u, p, t) = p .* u .* (1 .- u)
prob = ODEProblem(f, [1e-1], (0.0, 5), [3.0])
@testset "without jacobian" begin
# first without defined jac
Expand Down Expand Up @@ -211,3 +211,19 @@ end
sol2 = solve(prob, RadauIIA5())
@test sol1[end] sol2[end] rtol = 1e-5
end

@testset "EK1 Jacobian computation" begin
prob = prob_ode_fitzhughnagumo
@assert isnothing(prob.f.jac)

# make sure that the kwarg works
sol1 = solve(prob, EK1())
sol2 = solve(prob, EK1(autodiff=false))
@test sol2 isa ProbNumDiffEq.ProbODESolution

# check that forwarddiff leads to a smaller nf than finite diff
@test sol1.destats.nf < sol2.destats.nf

# use the EK1 on a non-forwarddiffable function
# TODO
end

2 comments on commit dfd46c8

@nathanaelbosch
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/48176

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" dfd46c8b471e69d2eafe81daad898e441dc82c83
git push origin v0.4.0

Please sign in to comment.