Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BAOAB algorithm #397

Merged
merged 6 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ module StochasticDiffEq
include("caches/predcorr_caches.jl")
include("caches/SROCK_caches.jl")
include("caches/tau_caches.jl")
include("caches/dynamical_caches.jl")
include("integrators/type.jl")
include("dense.jl")
include("alg_utils.jl")
Expand All @@ -110,6 +111,7 @@ module StochasticDiffEq
include("perform_step/composite.jl")
include("perform_step/SROCK_perform_step.jl")
include("perform_step/tau_leaping.jl")
include("perform_step/dynamical.jl")
include("tableaus.jl")
include("SROCK_tableaus.jl")
include("iterated_integrals.jl")
Expand Down Expand Up @@ -145,6 +147,8 @@ module StochasticDiffEq

export TauLeaping, CaoTauLeaping

export BAOAB

export StochasticDiffEqRODEAlgorithm, StochasticDiffEqRODEAdaptiveAlgorithm,
StochasticDiffEqRODECompositeAlgorithm

Expand Down
3 changes: 3 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ alg_order(alg::SMEB) = 1//1
alg_order(alg::TauLeaping) = 1//1
alg_order(alg::CaoTauLeaping) = 1//1

alg_order(alg::BAOAB) = 1//1 # Not sure what order it is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that would be good to look into


alg_order(alg::SKenCarp) = 2//1
alg_order(alg::Union{StochasticDiffEqCompositeAlgorithm,StochasticDiffEqRODECompositeAlgorithm}) = maximum(alg_order.(alg.algs))
get_current_alg_order(alg::StochasticDiffEqAlgorithm,cache) = alg_order(alg)
Expand Down Expand Up @@ -196,6 +198,7 @@ alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::RKMilGeneral) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::IIF1M) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::IIF2M) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::Union{StochasticDiffEqCompositeAlgorithm,StochasticDiffEqRODECompositeAlgorithm}) = max((alg_compatible(prob,a) for a in alg.algs)...)
alg_compatible(prob::DiffEqBase.AbstractSDEProblem,alg::BAOAB) = is_diagonal_noise(prob)

function alg_compatible(prob::JumpProblem,alg::Union{StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpAlgorithm})
prob.prob isa DiscreteProblem
Expand Down
10 changes: 10 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,13 @@ end
struct RandomEM <: StochasticDiffEqRODEAlgorithm end

const SplitSDEAlgorithms = Union{IIF1M,IIF2M,IIF1Mil,SKenCarp,SplitEM}

"""
Leimkuhler B., Matthews C., Robust and efficient configurational molecular sampling via
Langevin dynamics, J. Chem. Phys. 138, 174102 (2013)
DOI:10.1063/1.4802990
"""
struct BAOAB{T} <: StochasticDiffEqAlgorithm
gamma::T
end
BAOAB(;gamma=1.0) = BAOAB(gamma)
39 changes: 39 additions & 0 deletions src/caches/dynamical_caches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

struct BAOABConstantCache{uType,uEltypeNoUnits} <: StochasticDiffEqConstantCache
k::uType
half::uEltypeNoUnits
c1::uEltypeNoUnits
c2::uEltypeNoUnits
end
@cache struct BAOABCache{uType,uEltypeNoUnits,rateNoiseType} <: StochasticDiffEqMutableCache
utmp::uType
dutmp::uType
k::uType
gtmp::uType
noise::rateNoiseType
half::uEltypeNoUnits
c1::uEltypeNoUnits
c2::uEltypeNoUnits
end

function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{false}})
k = zero(rate_prototype.x[1])
c1 = exp(-alg.gamma*dt)
c2 = sqrt(1 - c1^2)
BAOABConstantCache(k, uEltypeNoUnits(1//2), uEltypeNoUnits(c1), uEltypeNoUnits(c2))
end

function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{true}})
dutmp = zero(u.x[1])
utmp = zero(u.x[2])
k = zero(rate_prototype.x[1])

gtmp = zero(rate_prototype.x[1])
noise = zero(rate_prototype.x[1])

half = uEltypeNoUnits(1//2)
c1 = exp(-alg.gamma*dt)
c2 = sqrt(1 - c1^2)

BAOABCache(utmp, dutmp, k, gtmp, noise, half, uEltypeNoUnits(c1), uEltypeNoUnits(c2))
end
81 changes: 81 additions & 0 deletions src/perform_step/dynamical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
function verify_f2(f, p, q, pa, t, integrator, ::BAOABConstantCache)
res = f(p, q, pa, t)
res != p && throwex(integrator)
end
function verify_f2(f, res, p, q, pa, t, integrator, ::BAOABCache)
f(res, p, q, pa, t)
res != p && throwex(integrator)
end
function throwex(integrator)
algn = typeof(integrator.alg)
throw(ArgumentError("Algorithm $algn is not applicable if f2(p, q, t) != p"))
end

function initialize!(integrator, cache::BAOABConstantCache)
@unpack t,dt,uprev,u,p,W = integrator
du1 = integrator.uprev.x[1]
u1 = integrator.uprev.x[2]

verify_f2(integrator.f.f2, du1, u1, p, t, integrator, cache)
cache.k .= integrator.f.f1(du1,u1,p,t)
end

function initialize!(integrator, cache::BAOABCache)
@unpack t,dt,uprev,u,p,W = integrator
du1 = integrator.uprev.x[1]
u1 = integrator.uprev.x[2]

verify_f2(integrator.f.f2, cache.k, du1, u1, p, t, integrator, cache)
integrator.f.f1(cache.k,du1,u1,p,t)
end

@muladd function perform_step!(integrator,cache::BAOABConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,p,W = integrator
@unpack k, half, c1, c2 = cache
du1 = uprev.x[1]
u1 = uprev.x[2]

# B
du2 = du1 + half*dt*k

# A
u2 = u1 + half*dt*du2

# O
noise = integrator.g(u2,p,t+dt*half).*W.dW
du3 = c1*du2 + c2*noise

# A
u = u2 + half*dt*du3

# B
k .= f.f1(du3,u,p,t+dt)
du = du3 + half*dt*k

integrator.u = ArrayPartition((du, u))
end

@muladd function perform_step!(integrator,cache::BAOABCache,f=integrator.f)
@unpack t,dt,uprev,u,p,W = integrator
@unpack utmp, dutmp, k, gtmp, noise, half, c1, c2 = cache
du1 = uprev.x[1]
u1 = uprev.x[2]

# B
@.. dutmp = du1 + half*dt*k

# A
@.. utmp = u1 + half*dt*dutmp

# O
integrator.g(gtmp,utmp,p,t+dt*half)
@.. noise = gtmp*W.dW
@.. dutmp = c1*dutmp + c2*noise

# A
@.. u.x[2] = utmp + half*dt*dutmp

# B
f.f1(k,dutmp,u.x[2],p,t+dt)
@.. u.x[1] = dutmp + half*dt*k
end
8 changes: 6 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ function DiffEqBase.__init(
end
rateType = typeof(rate_prototype) ## Can be different if united

if is_diagonal_noise(prob)
if prob.f isa DynamicalSDEFunction
noise_rate_prototype = rate_prototype.x[1]
elseif is_diagonal_noise(prob)
noise_rate_prototype = rate_prototype
elseif prob isa DiffEqBase.AbstractRODEProblem
if prob isa DiffEqBase.AbstractSDEProblem
Expand Down Expand Up @@ -273,7 +275,9 @@ function DiffEqBase.__init(
randType = typeof(rand_prototype)
else
randElType = uBottomEltypeNoUnits # Strip units and type info
if is_diagonal_noise(prob)
if prob.f isa DynamicalSDEFunction
rand_prototype = copy(noise_rate_prototype)
elseif is_diagonal_noise(prob)
if typeof(u) <: SArray
rand_prototype = zero(u) # TODO: Array{randElType} for units
else
Expand Down
25 changes: 25 additions & 0 deletions test/sde/sde_dynamical.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

jamesgardner1421 marked this conversation as resolved.
Show resolved Hide resolved
using StochasticDiffEq, DiffEqNoiseProcess, Test, DiffEqDevTools
jamesgardner1421 marked this conversation as resolved.
Show resolved Hide resolved
using Plots
using RecursiveArrayTools

u0 = zeros(2)
v0 = ones(2)
γ = 1
f1_harmonic(v,u,p,t) = -u
f2_harmonic(v,u,p,t) = v
g(u,p,t) = 0.2

ff_harmonic = DynamicalSDEFunction(f1_harmonic,f2_harmonic,g)
prob = DynamicalSDEProblem(ff_harmonic,g,v0,u0,(0.0,5.0))

sol1 = solve(prob,BAOAB(gamma=γ);dt=1/10,save_noise=true)

f1_harmonic_iip(dv,v,u,p,t) = dv .= f1_harmonic(v,u,p,t)
f2_harmonic_iip(du,v,u,p,t) = du .= f2_harmonic(v,u,p,t)
g_iip(du,u,p,t) = du .= g(u,p,t)

prob = DynamicalSDEProblem(f1_harmonic_iip,f2_harmonic_iip,g_iip,v0,u0,(0.0,5.0); noise=NoiseWrapper(sol1.W))

sol2 = solve(prob,BAOAB(gamma=γ);dt=1/10)
@test sol1[:] ≈ sol2[:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should get convergence tests