Skip to content

Commit

Permalink
Merge pull request #582 from prbzrg/trait2-1096
Browse files Browse the repository at this point in the history
`alg_interpretation` from SciMLBase
  • Loading branch information
ChrisRackauckas authored Aug 30, 2024
2 parents e02a78b + dbfbb77 commit 79552a4
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 38 deletions.
32 changes: 16 additions & 16 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,22 +133,22 @@ beta1_default(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm

isdtchangeable(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = true

alg_interpretation(alg::StochasticDiffEqAlgorithm) = :Ito
alg_interpretation(alg::EulerHeun) = :Stratonovich
alg_interpretation(alg::LambaEulerHeun) = :Stratonovich
alg_interpretation(alg::KomBurSROCK2) = :Stratonovich
alg_interpretation(alg::RKMil{interpretation}) where {interpretation} = interpretation
alg_interpretation(alg::SROCK1{interpretation,E}) where {interpretation,E} = interpretation
alg_interpretation(alg::RKMilCommute) = alg.interpretation
alg_interpretation(alg::RKMilGeneral) = alg.interpretation
alg_interpretation(alg::ImplicitRKMil{CS,AD,F,P,FDT,ST,CJ,N,T2,Controller,interpretation}) where {CS,AD,F,P,FDT,ST,CJ,N,T2,Controller,interpretation} = interpretation

alg_interpretation(alg::RS1) = :Stratonovich
alg_interpretation(alg::RS2) = :Stratonovich

alg_interpretation(alg::NON) = :Stratonovich
alg_interpretation(alg::COM) = :Stratonovich
alg_interpretation(alg::NON2) = :Stratonovich
SciMLBase.alg_interpretation(alg::StochasticDiffEqAlgorithm) = :Ito
SciMLBase.alg_interpretation(alg::EulerHeun) = :Stratonovich
SciMLBase.alg_interpretation(alg::LambaEulerHeun) = :Stratonovich
SciMLBase.alg_interpretation(alg::KomBurSROCK2) = :Stratonovich
SciMLBase.alg_interpretation(alg::RKMil{interpretation}) where {interpretation} = interpretation
SciMLBase.alg_interpretation(alg::SROCK1{interpretation,E}) where {interpretation,E} = interpretation
SciMLBase.alg_interpretation(alg::RKMilCommute) = alg.interpretation
SciMLBase.alg_interpretation(alg::RKMilGeneral) = alg.interpretation
SciMLBase.alg_interpretation(alg::ImplicitRKMil{CS,AD,F,P,FDT,ST,CJ,N,T2,Controller,interpretation}) where {CS,AD,F,P,FDT,ST,CJ,N,T2,Controller,interpretation} = interpretation

SciMLBase.alg_interpretation(alg::RS1) = :Stratonovich
SciMLBase.alg_interpretation(alg::RS2) = :Stratonovich

SciMLBase.alg_interpretation(alg::NON) = :Stratonovich
SciMLBase.alg_interpretation(alg::COM) = :Stratonovich
SciMLBase.alg_interpretation(alg::NON2) = :Stratonovich

alg_compatible(prob, alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = true
alg_compatible(prob, alg::StochasticDiffEqAlgorithm) = false
Expand Down
12 changes: 6 additions & 6 deletions src/perform_step/SROCK_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
cosh_inv = log(ω₀ + Sqrt_ω) # arcosh(ω₀)
ω₁ = (Sqrt_ω*cosh(mdeg*cosh_inv))/(mdeg*sinh(mdeg*cosh_inv))

if alg_interpretation(integrator.alg) == :Stratonovich
if SciMLBase.alg_interpretation(integrator.alg) == :Stratonovich
α = cosh(mdeg*cosh_inv)/(2*ω₀*cosh((mdeg-1)*cosh_inv))
γ = 1/(2*α)
β = -γ
Expand Down Expand Up @@ -42,7 +42,7 @@
k = integrator.f(uᵢ₋₁,p,tᵢ₋₁)

u = dt*μ*k + ν*uᵢ₋₁ + κ*uᵢ₋₂
if (i > mdeg - 2) && alg_interpretation(integrator.alg) == :Stratonovich
if (i > mdeg - 2) && SciMLBase.alg_interpretation(integrator.alg) == :Stratonovich
if i == mdeg - 1
gₘ₋₂ = integrator.g(uᵢ₋₁,p,tᵢ₋₁)
if W.dW isa Number || !is_diagonal_noise(integrator.sol.prob)
Expand All @@ -58,7 +58,7 @@
u .+=.* gₘ₋₂ .+ γ .* gₘ₋₁) .* W.dW
end
end
elseif (i == mdeg) && alg_interpretation(integrator.alg) == :Ito
elseif (i == mdeg) && SciMLBase.alg_interpretation(integrator.alg) == :Ito
if W.dW isa Number
gₘ₋₂ = integrator.g(uᵢ₋₁,p,tᵢ₋₁)
uᵢ₋₂ = uᵢ₋₁ + sqrt(abs(dt))*gₘ₋₂
Expand Down Expand Up @@ -105,7 +105,7 @@ end
cosh_inv = log(ω₀ + Sqrt_ω) # arcosh(ω₀)
ω₁ = (Sqrt_ω*cosh(mdeg*cosh_inv))/(mdeg*sinh(mdeg*cosh_inv))

if alg_interpretation(integrator.alg) == :Stratonovich
if SciMLBase.alg_interpretation(integrator.alg) == :Stratonovich
α = cosh(mdeg*cosh_inv)/(2*ω₀*cosh((mdeg-1)*cosh_inv))
γ = 1/(2*α)
β = -γ
Expand All @@ -132,7 +132,7 @@ end
κ = - Tᵢ₋₂/Tᵢ
integrator.f(k,uᵢ₋₁,p,tᵢ₋₁)
@.. u = dt*μ*k + ν*uᵢ₋₁ + κ*uᵢ₋₂
if (i > mdeg - 2) && alg_interpretation(integrator.alg) == :Stratonovich
if (i > mdeg - 2) && SciMLBase.alg_interpretation(integrator.alg) == :Stratonovich
if i == mdeg - 1
integrator.g(gₘ₋₂,uᵢ₋₁,p,tᵢ₋₁)
if W.dW isa Number || is_diagonal_noise(integrator.sol.prob)
Expand All @@ -152,7 +152,7 @@ end
@.. u += γ*k
end
end
elseif (i == mdeg) && alg_interpretation(integrator.alg) == :Ito
elseif (i == mdeg) && SciMLBase.alg_interpretation(integrator.alg) == :Ito
if W.dW isa Number || is_diagonal_noise(integrator.sol.prob)
integrator.g(gₘ₋₂,uᵢ₋₁,p,tᵢ₋₁)
@.. uᵢ₋₂ = uᵢ₋₁ + sqrt(abs(dt))*gₘ₋₂
Expand Down
24 changes: 12 additions & 12 deletions src/perform_step/low_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ end
K = @.. uprev + dt * du1
L = integrator.g(uprev,p,t)
mil_correction = zero(u)
if alg_interpretation(integrator.alg) == :Ito
if SciMLBase.alg_interpretation(integrator.alg) == :Ito
utilde = K + L*integrator.sqdt
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
mil_correction = ggprime.*(W.dW.^2 .- abs(dt))./2
elseif alg_interpretation(integrator.alg) == :Stratonovich
elseif SciMLBase.alg_interpretation(integrator.alg) == :Stratonovich
utilde = uprev + L*integrator.sqdt
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
mil_correction = ggprime.*(W.dW.^2)./2
Expand Down Expand Up @@ -241,11 +241,11 @@ end
integrator.g(L,uprev,p,t)
@.. K = uprev + dt * du1
@.. du2 = zero(eltype(u)) # This makes it safe to re-use the array
if alg_interpretation(integrator.alg) == :Ito
if SciMLBase.alg_interpretation(integrator.alg) == :Ito
@.. tmp = K + integrator.sqdt * L
integrator.g(du2,tmp,p,t)
@.. tmp = (du2-L)/(2integrator.sqdt)*(W.dW.^2 - abs(dt))
elseif alg_interpretation(integrator.alg) == :Stratonovich
elseif SciMLBase.alg_interpretation(integrator.alg) == :Stratonovich
@.. tmp = uprev + integrator.sqdt * L
integrator.g(du2,tmp,p,t)
@.. tmp = (du2-L)/(2integrator.sqdt)*(W.dW.^2)
Expand Down Expand Up @@ -275,7 +275,7 @@ end
J = get_iterated_I(dt, dW, W.dZ, Jalg)

mil_correction = zero(u)
if alg_interpretation(integrator.alg) == :Ito
if SciMLBase.alg_interpretation(integrator.alg) == :Ito
if dW isa Number || is_diagonal_noise(integrator.sol.prob)
J = J .- 1//2 .* abs(dt)
else
Expand All @@ -289,7 +289,7 @@ end
K = uprev + dt*du1

if is_diagonal_noise(integrator.sol.prob)
tmp = (alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
tmp = (SciMLBase.alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
gtmp = integrator.g(tmp,p,t)
Dgj = (gtmp - L)/sqdt
ggprime_norm = integrator.opts.internalnorm(Dgj,t)
Expand Down Expand Up @@ -343,7 +343,7 @@ end
J = Jalg.J

@.. mil_correction = zero(u)
if alg_interpretation(integrator.alg) == :Ito
if SciMLBase.alg_interpretation(integrator.alg) == :Ito
if dW isa Number || is_diagonal_noise(integrator.sol.prob)
@.. J -= 1 // 2 * abs(dt)
else
Expand All @@ -357,7 +357,7 @@ end
@.. K = uprev + dt*du1

if is_diagonal_noise(integrator.sol.prob)
tmp .= (alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
tmp .= (SciMLBase.alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
integrator.g(gtmp,tmp,p,t)
@.. Dgj = (gtmp - L)/sqdt
ggprime_norm = integrator.opts.internalnorm(Dgj,t)
Expand Down Expand Up @@ -397,7 +397,7 @@ end

J = get_iterated_I(dt, dW, W.dZ, Jalg, integrator.alg.p, integrator.alg.c, alg_order(integrator.alg))

if alg_interpretation(integrator.alg) == :Ito
if SciMLBase.alg_interpretation(integrator.alg) == :Ito
if dW isa Number || is_diagonal_noise(integrator.sol.prob)
J = J .- 1//2 .* abs(dt)
else
Expand All @@ -413,7 +413,7 @@ end

if dW isa Number || is_diagonal_noise(integrator.sol.prob)
K = @.. uprev + dt*du₁
utilde = (alg_interpretation(integrator.alg) == :Ito ? K : uprev) + L*integrator.sqdt
utilde = (SciMLBase.alg_interpretation(integrator.alg) == :Ito ? K : uprev) + L*integrator.sqdt
ggprime = (integrator.g(utilde,p,t) .- L) ./ (integrator.sqdt)
mil_correction = ggprime .* J
u = K + L .* dW + mil_correction
Expand Down Expand Up @@ -467,7 +467,7 @@ end
@.. mil_correction = zero(eltype(u))
ggprime_norm = zero(eltype(ggprime))

if alg_interpretation(integrator.alg) == :Ito
if SciMLBase.alg_interpretation(integrator.alg) == :Ito
if dW isa Number || is_diagonal_noise(integrator.sol.prob)
@.. J -= 1 // 2 * abs(dt)
else
Expand All @@ -478,7 +478,7 @@ end
if dW isa Number || is_diagonal_noise(integrator.sol.prob)
@.. K = uprev + dt*du₁
@.. du₂ = zero(eltype(u))
tmp .= (alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
tmp .= (SciMLBase.alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
integrator.g(du₂,tmp,p,t)
@.. ggprime = (du₂ - L)/sqdt
ggprime_norm = integrator.opts.internalnorm(ggprime,t)
Expand Down
8 changes: 4 additions & 4 deletions src/perform_step/sdirk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
end

if cache isa ImplicitRKMilConstantCache || integrator.opts.adaptive == true
if alg_interpretation(alg) == :Ito ||
if SciMLBase.alg_interpretation(alg) == :Ito ||
cache isa ImplicitEMConstantCache
K = @.. uprev + dt * ftmp
utilde = K + L*integrator.sqdt
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
mil_correction = ggprime .* (integrator.W.dW.^2 .- abs(dt))./2
gtmp += mil_correction
elseif alg_interpretation(alg) == :Stratonovich ||
elseif SciMLBase.alg_interpretation(alg) == :Stratonovich ||
cache isa ImplicitEulerHeunConstantCache
utilde = uprev + L*integrator.sqdt
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
Expand Down Expand Up @@ -154,12 +154,12 @@ end

if cache isa ImplicitRKMilCache
gtmp3 = cache.gtmp3
if alg_interpretation(alg) == :Ito
if SciMLBase.alg_interpretation(alg) == :Ito
@.. z = uprev + dt * tmp + integrator.sqdt * gtmp
integrator.g(gtmp3,z,p,t)
@.. gtmp3 = (gtmp3-gtmp)/(integrator.sqdt) # ggprime approximation
@.. gtmp2 += gtmp3*(dW.^2 - abs(dt))/2
elseif alg_interpretation(alg) == :Stratonovich
elseif SciMLBase.alg_interpretation(alg) == :Stratonovich
@.. z = uprev + integrator.sqdt * gtmp
integrator.g(gtmp3,z,p,t)
@.. gtmp3 = (gtmp3-gtmp)/(integrator.sqdt) # ggprime approximation
Expand Down

0 comments on commit 79552a4

Please sign in to comment.