Skip to content

Commit

Permalink
Merge pull request #571 from SciML/W2Ito1
Browse files Browse the repository at this point in the history
W2Ito1 scheme
  • Loading branch information
ChrisRackauckas authored Nov 4, 2024
2 parents 07bd47f + 1ef01f8 commit 444328d
Show file tree
Hide file tree
Showing 23 changed files with 729 additions and 38 deletions.
8 changes: 4 additions & 4 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ steps:
GROUP: "{{matrix}}"
plugins:
- JuliaCI/julia#v1:
version: "1"
version: "1.10"
- JuliaCI/julia-test#v1:
coverage: false
julia_args: "--threads=auto"
Expand All @@ -32,7 +32,7 @@ steps:
GROUP: "{{matrix}}"
plugins:
- JuliaCI/julia#v1:
version: "1"
version: "1.10"
- JuliaCI/julia-test#v1:
coverage: false
julia_args: "--threads=auto"
Expand All @@ -53,7 +53,7 @@ steps:
GROUP: "{{matrix}}"
plugins:
- JuliaCI/julia#v1:
version: "1"
version: "1.10"
- JuliaCI/julia-test#v1:
coverage: false
julia_args: "--threads=auto"
Expand All @@ -68,7 +68,7 @@ steps:
- label: "WeakAdaptiveGPU"
plugins:
- JuliaCI/julia#v1:
version: "1"
version: "1.10"
- JuliaCI/julia-test#v1:
coverage: false
agents:
Expand Down
2 changes: 1 addition & 1 deletion src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ end
SROCK1, SROCK2, SROCKEM, SKSROCK, TangXiaoSROCK2, KomBurSROCK2, SROCKC2,
WangLi3SMil_A, WangLi3SMil_B, WangLi3SMil_C, WangLi3SMil_D, WangLi3SMil_E, WangLi3SMil_F,
AutoSOSRI2, AutoSOSRA2,
DRI1, DRI1NM, RI1, RI3, RI5, RI6, RDI1WM, RDI2WM, RDI3WM, RDI4WM,
DRI1, DRI1NM, RI1, RI3, RI5, RI6, RDI1WM, RDI2WM, RDI3WM, RDI4WM, W2Ito1,
RS1, RS2,
PL1WM, PL1WMA,
NON, COM, NON2
Expand Down
3 changes: 3 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ alg_order(alg::RDI1WM) = 1 // 1
alg_order(alg::RDI2WM) = 1 // 1
alg_order(alg::RDI3WM) = 1 // 1
alg_order(alg::RDI4WM) = 1 // 1
alg_order(alg::W2Ito1) = 1 // 1

alg_order(alg::RS1) = 1 // 1
alg_order(alg::RS2) = 1 // 1
Expand Down Expand Up @@ -181,6 +182,7 @@ alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI1WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI2WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI3WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI4WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::W2Ito1) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RS1) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RS2) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::PL1WM) = true
Expand Down Expand Up @@ -254,6 +256,7 @@ alg_needs_extra_process(alg::RDI1WM) = true
alg_needs_extra_process(alg::RDI2WM) = true
alg_needs_extra_process(alg::RDI3WM) = true
alg_needs_extra_process(alg::RDI4WM) = true
alg_needs_extra_process(alg::W2Ito1) = true
alg_needs_extra_process(alg::RS1) = true
alg_needs_extra_process(alg::RS2) = true
alg_needs_extra_process(alg::PL1WM) = true
Expand Down
13 changes: 13 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,19 @@ Can handle diagonal, non-diagonal, non-commuting, and scalar additive noise.
"""
struct RDI4WM <: StochasticDiffEqAdaptiveAlgorithm end


"""
Tang, X., & Xiao, A., Efficient weak second-order stochastic Runge–Kutta methods
for Itô stochastic differential equations,
BIT Numerical Mathematics, 57, 241-260 (2017)
DOI: 10.1007/s10543-016-0618-9
W2Ito1: High Weak Order Method
Adaptive step weak order 2.0 for Ito SDEs (deterministic order 3).
Can handle diagonal, non-diagonal, non-commuting, and scalar additive noise.
"""
struct W2Ito1 <: StochasticDiffEqAdaptiveAlgorithm end

# Stratonovich sense

"""
Expand Down
164 changes: 164 additions & 0 deletions src/caches/srk_weak_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2318,3 +2318,167 @@ function alg_cache(alg::SMEB,prob,u,ΔW,ΔZ,p,rate_prototype,

SIESMECache(u,uprev,W2,W3,tab,k0,k1,g0,g1,g2,tmpu)
end





# Tang & Xiao: DOI 10.1007/s10543-016-0618-9 W2Ito1 and W2Ito2 methods

struct W2Ito1ConstantCache{T} <: StochasticDiffEqConstantCache
# hard-coded version
a021::T
a031::T
a032::T

a121::T
a131::T
#a132::T

#a221::T
#a231::T
#a232::T

b021::T
b031::T
#b032::T

b121::T
b131::T
#b132::T

b221::T
#b222::T
#b223::T
#b231::T
#b232::T
#b233::T

α1::T
α2::T
α3::T

beta01::T
beta02::T
beta03::T

beta11::T
#beta12::T
beta13::T

#quantile(Normal(),1/6)
NORMAL_ONESIX_QUANTILE::T
end



function W2Ito1ConstantCache(::Type{T}, ::Type{T2}) where {T,T2}

a021 = convert(T, 1 // 2)
a031 = convert(T, -1)
a032 = convert(T, 2)

a121 = convert(T, 1 // 4)
a131 = convert(T, 1 // 4)

b021 = convert(T, (6 - sqrt(6)) / 10)
b031 = convert(T, (3 + 2 * sqrt(6)) / 5)

b121 = convert(T, 1 // 2)
b131 = convert(T, -1 // 2)

b221 = convert(T, 1)

α1 = convert(T, 1 // 6)
α2 = convert(T, 2 // 3)
α3 = convert(T, 1 // 6)

beta01 = convert(T, -1)
beta02 = convert(T, 1)
beta03 = convert(T, 1)

beta11 = convert(T, 2)
beta13 = convert(T, -2)

NORMAL_ONESIX_QUANTILE = convert(T, -0.9674215661017014)

W2Ito1ConstantCache(a021, a031, a032, a121, a131, b021, b031, b121, b131, b221, α1, α2, α3, beta01, beta02, beta03, beta11, beta13, NORMAL_ONESIX_QUANTILE)
end


function alg_cache(alg::W2Ito1, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
W2Ito1ConstantCache(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
end

@cache struct W2Ito1Cache{uType,randType,tabType,rateNoiseType,rateType,possibleRateType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
uhat::uType

_dW::randType
_dZ::randType
chi1::randType

tab::tabType

g1::rateNoiseType
g2::rateNoiseType
g3::rateNoiseType

k1::rateType
k2::rateType
k3::rateType

H02::uType
H03::uType
H12::Vector{uType}
H13::Vector{uType}

tmp1::possibleRateType
tmpg::rateNoiseType

tmp::uType
resids::uType

end

function alg_cache(alg::W2Ito1, prob, u, ΔW, ΔZ, p, rate_prototype,
noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
if ΔW isa Union{SArray,Number}
_dW = copy(ΔW)
_dZ = zeros(eltype(ΔW), 2)
chi1 = copy(ΔW)
else
_dW = zero(ΔW)
_dZ = zeros(eltype(ΔW), 2)
chi1 = zero(ΔW)
end
m = length(ΔW)
tab = W2Ito1ConstantCache(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
g1 = zero(noise_rate_prototype)
g2 = zero(noise_rate_prototype)
g3 = zero(noise_rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)

H02 = zero(u)
H03 = zero(u)
H12 = Vector{typeof(u)}()
H13 = Vector{typeof(u)}()

for k = 1:m
push!(H12, zero(u))
push!(H13, zero(u))
end

tmp1 = zero(rate_prototype)
tmpg = zero(noise_rate_prototype)

uhat = copy(uprev)
tmp = zero(u)
resids = zero(u)

W2Ito1Cache(u, uprev, uhat, _dW, _dZ, chi1, tab, g1, g2, g3, k1, k2, k3, H02, H03, H12, H13, tmp1, tmpg, tmp, resids)
end
Loading

0 comments on commit 444328d

Please sign in to comment.