Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BAOAB algorithm #397

Merged
merged 6 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ module StochasticDiffEq
include("caches/predcorr_caches.jl")
include("caches/SROCK_caches.jl")
include("caches/tau_caches.jl")
include("caches/dynamical_caches.jl")
include("integrators/type.jl")
include("dense.jl")
include("alg_utils.jl")
Expand All @@ -110,6 +111,7 @@ module StochasticDiffEq
include("perform_step/composite.jl")
include("perform_step/SROCK_perform_step.jl")
include("perform_step/tau_leaping.jl")
include("perform_step/dynamical.jl")
include("tableaus.jl")
include("SROCK_tableaus.jl")
include("iterated_integrals.jl")
Expand Down Expand Up @@ -145,6 +147,8 @@ module StochasticDiffEq

export TauLeaping, CaoTauLeaping

export BAOAB

export StochasticDiffEqRODEAlgorithm, StochasticDiffEqRODEAdaptiveAlgorithm,
StochasticDiffEqRODECompositeAlgorithm

Expand Down
3 changes: 3 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ alg_order(alg::SMEB) = 1//1
alg_order(alg::TauLeaping) = 1//1
alg_order(alg::CaoTauLeaping) = 1//1

alg_order(alg::BAOAB) = 2//1

alg_order(alg::SKenCarp) = 2//1
alg_order(alg::Union{StochasticDiffEqCompositeAlgorithm,StochasticDiffEqRODECompositeAlgorithm}) = maximum(alg_order.(alg.algs))
get_current_alg_order(alg::StochasticDiffEqAlgorithm,cache) = alg_order(alg)
Expand Down Expand Up @@ -196,6 +198,7 @@ alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::RKMilGeneral) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::IIF1M) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::IIF2M) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::Union{StochasticDiffEqCompositeAlgorithm,StochasticDiffEqRODECompositeAlgorithm}) = max((alg_compatible(prob,a) for a in alg.algs)...)
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::BAOAB) = is_diagonal_noise(prob)

function alg_compatible(prob::JumpProblem,alg::Union{StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpAlgorithm})
prob.prob isa DiscreteProblem
Expand Down
15 changes: 15 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,18 @@ end
struct RandomEM <: StochasticDiffEqRODEAlgorithm end

const SplitSDEAlgorithms = Union{IIF1M,IIF2M,IIF1Mil,SKenCarp,SplitEM}

@doc raw"""
Leimkuhler B., Matthews C., Robust and efficient configurational molecular sampling via
Langevin dynamics, J. Chem. Phys. 138, 174102 (2013)
DOI:10.1063/1.4802990

```math
du = vdt \\
dv = f(v,u) dt - \gamma v dt + g(u) dW
```
"""
struct BAOAB{T} <: StochasticDiffEqAlgorithm
gamma::T
end
BAOAB(;gamma=1.0) = BAOAB(gamma)
39 changes: 39 additions & 0 deletions src/caches/dynamical_caches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

struct BAOABConstantCache{uType,uEltypeNoUnits} <: StochasticDiffEqConstantCache
k::uType
half::uEltypeNoUnits
c1::uEltypeNoUnits
c2::uEltypeNoUnits
end
@cache struct BAOABCache{uType,uEltypeNoUnits,rateNoiseType} <: StochasticDiffEqMutableCache
utmp::uType
dutmp::uType
k::uType
gtmp::uType
noise::rateNoiseType
half::uEltypeNoUnits
c1::uEltypeNoUnits
c2::uEltypeNoUnits
end

function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{false}})
k = zero(rate_prototype.x[1])
c1 = exp(-alg.gamma*dt)
c2 = sqrt(1 - c1^2) / sqrt(2alg.gamma)
BAOABConstantCache(k, uEltypeNoUnits(1//2), uEltypeNoUnits(c1), uEltypeNoUnits(c2))
end

function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{true}})
dutmp = zero(u.x[1])
utmp = zero(u.x[2])
k = zero(rate_prototype.x[1])

gtmp = zero(rate_prototype.x[1])
noise = zero(rate_prototype.x[1])

half = uEltypeNoUnits(1//2)
c1 = exp(-alg.gamma*dt)
c2 = sqrt(1 - c1^2) / sqrt(2alg.gamma)

BAOABCache(utmp, dutmp, k, gtmp, noise, half, uEltypeNoUnits(c1), uEltypeNoUnits(c2))
end
81 changes: 81 additions & 0 deletions src/perform_step/dynamical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
function verify_f2(f, p, q, pa, t, integrator, ::BAOABConstantCache)
res = f(p, q, pa, t)
res != p && throwex(integrator)
end
function verify_f2(f, res, p, q, pa, t, integrator, ::BAOABCache)
f(res, p, q, pa, t)
res != p && throwex(integrator)
end
function throwex(integrator)
algn = typeof(integrator.alg)
throw(ArgumentError("Algorithm $algn is not applicable if f2(p, q, t) != p"))
end

function initialize!(integrator, cache::BAOABConstantCache)
@unpack t,dt,uprev,u,p,W = integrator
du1 = integrator.uprev.x[1]
u1 = integrator.uprev.x[2]

verify_f2(integrator.f.f2, du1, u1, p, t, integrator, cache)
cache.k .= integrator.f.f1(du1,u1,p,t)
end

function initialize!(integrator, cache::BAOABCache)
@unpack t,dt,uprev,u,p,W = integrator
du1 = integrator.uprev.x[1]
u1 = integrator.uprev.x[2]

verify_f2(integrator.f.f2, cache.k, du1, u1, p, t, integrator, cache)
integrator.f.f1(cache.k,du1,u1,p,t)
end

@muladd function perform_step!(integrator,cache::BAOABConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,p,W = integrator
@unpack k, half, c1, c2 = cache
du1 = uprev.x[1]
u1 = uprev.x[2]

# B
du2 = du1 + half*dt*k

# A
u2 = u1 + half*dt*du2

# O
noise = integrator.g(u2,p,t+dt*half).*W.dW
du3 = c1*du2 + c2*noise

# A
u = u2 + half*dt*du3

# B
k .= f.f1(du3,u,p,t+dt)
du = du3 + half*dt*k

integrator.u = ArrayPartition((du, u))
end

@muladd function perform_step!(integrator,cache::BAOABCache,f=integrator.f)
@unpack t,dt,uprev,u,p,W = integrator
@unpack utmp, dutmp, k, gtmp, noise, half, c1, c2 = cache
du1 = uprev.x[1]
u1 = uprev.x[2]

# B
@.. dutmp = du1 + half*dt*k

# A
@.. utmp = u1 + half*dt*dutmp

# O
integrator.g(gtmp,utmp,p,t+dt*half)
@.. noise = gtmp*W.dW
@.. dutmp = c1*dutmp + c2*noise

# A
@.. u.x[2] = utmp + half*dt*dutmp

# B
f.f1(k,dutmp,u.x[2],p,t+dt)
@.. u.x[1] = dutmp + half*dt*k
end
8 changes: 6 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ function DiffEqBase.__init(
end
rateType = typeof(rate_prototype) ## Can be different if united

if is_diagonal_noise(prob)
if prob.f isa DynamicalSDEFunction
noise_rate_prototype = rate_prototype.x[1]
elseif is_diagonal_noise(prob)
noise_rate_prototype = rate_prototype
elseif prob isa DiffEqBase.AbstractRODEProblem
if prob isa DiffEqBase.AbstractSDEProblem
Expand Down Expand Up @@ -273,7 +275,9 @@ function DiffEqBase.__init(
randType = typeof(rand_prototype)
else
randElType = uBottomEltypeNoUnits # Strip units and type info
if is_diagonal_noise(prob)
if prob.f isa DynamicalSDEFunction
rand_prototype = copy(noise_rate_prototype)
elseif is_diagonal_noise(prob)
if typeof(u) <: SArray
rand_prototype = zero(u) # TODO: Array{randElType} for units
else
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR")

if !is_APPVEYOR && (GROUP == "All" || GROUP == "AlgConvergence")
@time @safetestset "Convergence Tests" begin include("sde/sde_convergence_tests.jl") end
@time @safetestset "Dynamical SDE Tests" begin include("sde/sde_dynamical.jl") end
end

if !is_APPVEYOR && GROUP == "AlgConvergence2"
Expand Down
31 changes: 31 additions & 0 deletions test/sde/sde_dynamical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using StochasticDiffEq, DiffEqNoiseProcess, Test, DiffEqDevTools, Random
Random.seed!(1)

u0 = zeros(2)
v0 = ones(2)
γ = 50

f1_harmonic(v,u,p,t) = -u
f2_harmonic(v,u,p,t) = v
g(u,p,t) = sqrt(2γ)

f1_harmonic_iip(dv,v,u,p,t) = dv .= f1_harmonic(v,u,p,t)
f2_harmonic_iip(du,v,u,p,t) = du .= f2_harmonic(v,u,p,t)
g_iip(du,u,p,t) = du .= g(u,p,t)

ff_harmonic = DynamicalSDEFunction(f1_harmonic,f2_harmonic,g)
prob1 = DynamicalSDEProblem(ff_harmonic,g,v0,u0,(0.0,5.0))
sol1 = solve(prob1,BAOAB(gamma=γ);dt=1/10,save_noise=true)

prob2 = DynamicalSDEProblem(f1_harmonic_iip,f2_harmonic_iip,g_iip,v0,u0,(0.0,5.0); noise=NoiseWrapper(sol1.W))
sol2 = solve(prob2,BAOAB(gamma=γ);dt=1/10)

@test sol1[:] ≈ sol2[:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should get convergence tests


dts = (1/2) .^ (9:-1:3)
# Can't use NoiseGrid as noise is not generated with the correct size in convergence.jl. We require noise with shape of v.
# I don't think there are any analytic solutions for problems of this type
sim1 = analyticless_test_convergence(dts,prob1,BAOAB(gamma=γ),(1/2)^10;trajectories=Int(1e2),use_noise_grid=false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's definitely not enough trajectories to produce a weak convergence estimate. Do a local plot with like 5e5 trajectories and show me the plot(sim1).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

errors

Here's the plot with 5e4 trajectories with tspan=(0.0,0.5). I seemed to be running into memory issues if I try more trajectories or a longer tspan. However, it does seem to be converged now. I also tried with 1e4 trajectories and the plot is only marginally different.

Here's the estimate I obtain:

julia> sim1.𝒪est
Dict{Any,Any} with 2 entries:
  :weak_final => 1.45594
  :final      => 0.802602

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if you make the dts a bit smaller? the strong convergence estimate seems to not entire its asymtopic regime until half way through.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

errors

This is with only 1e4 trajectories as it takes a very long time with smaller timesteps. The test dt here is 0.5^12.

julia> sim1.𝒪est
Dict{Any,Any} with 2 entries:
  :weak_final => 1.19034
  :final      => 0.777283

This doesn't quite match up with the previous plot however so I guess it's not properly converged. I'm not sure I can run it for many more trajectories though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More trajectories won't help. You'd see that because it would flatline in a noisy fashion on the left end due to not having enough values for the sampling error. This doesn't show sampling eror. It shows 1st order convergence.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lower dts would be the check to get rid of that wobble though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok, does this mean the algorithm is 1st order then? I don't know much about SDEs, I was just going off the plots in the paper, but that's converging observables calculated from the distributions obtained from the integration. Is that different?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of the convergence diagrams aren't well-behaved in their paper either. Can you try and recreate one and see if it's around the same?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It turns out I had made a mistake in the algorithm which I've now fixed. The convergence is a lot more well behaved now and I was able to roughly reproduce one of their plots.
image
image

I did 2000 trajectories with 1e6 steps and they did 2000 with 1e8 so I think it's safe to assume I'd converge to the same result given enough timesteps.

@test abs(sim1.𝒪est[:weak_final]-1) < 0.3 # Gives order of only 1? Should be weak order 2.
sim2 = analyticless_test_convergence(dts,prob2,BAOAB(gamma=γ),(1/2)^10;trajectories=Int(1e2),use_noise_grid=false)
@test abs(sim2.𝒪est[:weak_final]-1) < 0.3 # Gives order of only 1? Should be weak order 2.