Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

W2Ito1 scheme #571

Merged
merged 9 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ end
SROCK1, SROCK2, SROCKEM, SKSROCK, TangXiaoSROCK2, KomBurSROCK2, SROCKC2,
WangLi3SMil_A, WangLi3SMil_B, WangLi3SMil_C, WangLi3SMil_D, WangLi3SMil_E, WangLi3SMil_F,
AutoSOSRI2, AutoSOSRA2,
DRI1, DRI1NM, RI1, RI3, RI5, RI6, RDI1WM, RDI2WM, RDI3WM, RDI4WM,
DRI1, DRI1NM, RI1, RI3, RI5, RI6, RDI1WM, RDI2WM, RDI3WM, RDI4WM, W2Ito1,
RS1, RS2,
PL1WM, PL1WMA,
NON, COM, NON2
Expand Down
3 changes: 3 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ alg_order(alg::RDI1WM) = 1 // 1
alg_order(alg::RDI2WM) = 1 // 1
alg_order(alg::RDI3WM) = 1 // 1
alg_order(alg::RDI4WM) = 1 // 1
alg_order(alg::W2Ito1) = 1 // 1

alg_order(alg::RS1) = 1 // 1
alg_order(alg::RS2) = 1 // 1
Expand Down Expand Up @@ -181,6 +182,7 @@ alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI1WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI2WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI3WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RDI4WM) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::W2Ito1) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RS1) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::RS2) = true
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::PL1WM) = true
Expand Down Expand Up @@ -254,6 +256,7 @@ alg_needs_extra_process(alg::RDI1WM) = true
alg_needs_extra_process(alg::RDI2WM) = true
alg_needs_extra_process(alg::RDI3WM) = true
alg_needs_extra_process(alg::RDI4WM) = true
alg_needs_extra_process(alg::W2Ito1) = true
alg_needs_extra_process(alg::RS1) = true
alg_needs_extra_process(alg::RS2) = true
alg_needs_extra_process(alg::PL1WM) = true
Expand Down
13 changes: 13 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,19 @@ Can handle diagonal, non-diagonal, non-commuting, and scalar additive noise.
"""
struct RDI4WM <: StochasticDiffEqAdaptiveAlgorithm end


"""
Tang, X., & Xiao, A., Efficient weak second-order stochastic Runge–Kutta methods
for Itô stochastic differential equations,
BIT Numerical Mathematics, 57, 241-260 (2017)
DOI: 10.1007/s10543-016-0618-9

W2Ito1: High Weak Order Method
Adaptive step weak order 2.0 for Ito SDEs (deterministic order 3).
Can handle diagonal, non-diagonal, non-commuting, and scalar additive noise.
"""
struct W2Ito1 <: StochasticDiffEqAdaptiveAlgorithm end

# Stratonovich sense

"""
Expand Down
164 changes: 164 additions & 0 deletions src/caches/srk_weak_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2318,3 +2318,167 @@ function alg_cache(alg::SMEB,prob,u,ΔW,ΔZ,p,rate_prototype,

SIESMECache(u,uprev,W2,W3,tab,k0,k1,g0,g1,g2,tmpu)
end





# Tang & Xiao: DOI 10.1007/s10543-016-0618-9 W2Ito1 and W2Ito2 methods

struct W2Ito1ConstantCache{T} <: StochasticDiffEqConstantCache
# hard-coded version
a021::T
a031::T
a032::T

a121::T
a131::T
#a132::T

#a221::T
#a231::T
#a232::T

b021::T
b031::T
#b032::T

b121::T
b131::T
#b132::T

b221::T
#b222::T
#b223::T
#b231::T
#b232::T
#b233::T

α1::T
α2::T
α3::T

beta01::T
beta02::T
beta03::T

beta11::T
#beta12::T
beta13::T

#quantile(Normal(),1/6)
NORMAL_ONESIX_QUANTILE::T
end



function W2Ito1ConstantCache(::Type{T}, ::Type{T2}) where {T,T2}

a021 = convert(T, 1 // 2)
a031 = convert(T, -1)
a032 = convert(T, 2)

a121 = convert(T, 1 // 4)
a131 = convert(T, 1 // 4)

b021 = convert(T, (6 - sqrt(6)) / 10)
b031 = convert(T, (3 + 2 * sqrt(6)) / 5)

b121 = convert(T, 1 // 2)
b131 = convert(T, -1 // 2)

b221 = convert(T, 1)

α1 = convert(T, 1 // 6)
α2 = convert(T, 2 // 3)
α3 = convert(T, 1 // 6)

beta01 = convert(T, -1)
beta02 = convert(T, 1)
beta03 = convert(T, 1)

beta11 = convert(T, 2)
beta13 = convert(T, -2)

NORMAL_ONESIX_QUANTILE = convert(T, -0.9674215661017014)

W2Ito1ConstantCache(a021, a031, a032, a121, a131, b021, b031, b121, b131, b221, α1, α2, α3, beta01, beta02, beta03, beta11, beta13, NORMAL_ONESIX_QUANTILE)
end


function alg_cache(alg::W2Ito1, prob, u, ΔW, ΔZ, p, rate_prototype, noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits}, ::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
W2Ito1ConstantCache(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
end

@cache struct W2Ito1Cache{uType,randType,tabType,rateNoiseType,rateType,possibleRateType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
uhat::uType

_dW::randType
_dZ::randType
chi1::randType

tab::tabType

g1::rateNoiseType
g2::rateNoiseType
g3::rateNoiseType

k1::rateType
k2::rateType
k3::rateType

H02::uType
H03::uType
H12::Vector{uType}
H13::Vector{uType}

tmp1::possibleRateType
tmpg::rateNoiseType

tmp::uType
resids::uType

end

function alg_cache(alg::W2Ito1, prob, u, ΔW, ΔZ, p, rate_prototype,
noise_rate_prototype, jump_rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, f, t, dt, ::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
if ΔW isa Union{SArray,Number}
_dW = copy(ΔW)
_dZ = zeros(eltype(ΔW), 2)
chi1 = copy(ΔW)
else
_dW = zero(ΔW)
_dZ = zeros(eltype(ΔW), 2)
chi1 = zero(ΔW)
end
m = length(ΔW)
tab = W2Ito1ConstantCache(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
g1 = zero(noise_rate_prototype)
g2 = zero(noise_rate_prototype)
g3 = zero(noise_rate_prototype)
k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
k3 = zero(rate_prototype)

H02 = zero(u)
H03 = zero(u)
H12 = Vector{typeof(u)}()
H13 = Vector{typeof(u)}()

for k = 1:m
push!(H12, zero(u))
push!(H13, zero(u))
end

tmp1 = zero(rate_prototype)
tmpg = zero(noise_rate_prototype)

uhat = copy(uprev)
tmp = zero(u)
resids = zero(u)

W2Ito1Cache(u, uprev, uhat, _dW, _dZ, chi1, tab, g1, g2, g3, k1, k2, k3, H02, H03, H12, H13, tmp1, tmpg, tmp, resids)
end
Loading
Loading