Skip to content

Commit

Permalink
Merge pull request #588 from SciML/nondiagonal_oop
Browse files Browse the repository at this point in the history
Handle non-diagonal noise in out of place SRA
  • Loading branch information
ChrisRackauckas authored Oct 29, 2024
2 parents 2930ba8 + f260567 commit 07bd47f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/perform_step/sra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,34 @@ end
g1 = integrator.g(uprev,p,t+c11*dt)
k1 = integrator.f(uprev,p,t)

H01 = uprev + dt*a21*k1 + b21*chi2.*g1
if is_diagonal_noise(integrator.sol.prob)
H01 = uprev + dt*a21*k1 + b21*chi2.*g1
else
H01 = uprev + dt*a21*k1 + b21*g1*chi2
end

g2 = integrator.g(H01,p,t+c12*dt)
k2 = integrator.f(H01,p,t+c02*dt)

H02 = uprev + dt*(a31*k1 + a32*k2) + chi2.*(b31*g1 + b32*g2)
if is_diagonal_noise(integrator.sol.prob)
H02 = uprev + dt*(a31*k1 + a32*k2) + chi2.*(b31*g1 + b32*g2)
else
H02 = uprev + dt*(a31*k1 + a32*k2) + (b31*g1 + b32*g2)*chi2
end


g3 = integrator.g(H02,p,t+c13*dt)
k3 = integrator.f(H02,p,t+c03*dt)

E₁ = dt*(α1*k1 + α2*k2 + α3*k3)
E₂ = chi2.*(beta21*g1 + beta22*g2 + beta23*g3)

u = uprev + E₁ + E₂ + W.dW.*(beta11*g1 + beta12*g2 + beta13*g3)
if is_diagonal_noise(integrator.sol.prob)
E₂ = chi2.*(beta21*g1 + beta22*g2 + beta23*g3)
u = uprev + E₁ + E₂ + W.dW.*(beta11*g1 + beta12*g2 + beta13*g3)
else
E₂ = (beta21*g1 + beta22*g2 + beta23*g3)*chi2
u = uprev + E₁ + E₂ + (beta11*g1 + beta12*g2 + beta13*g3)*W.dW
end

if integrator.alg isa StochasticCompositeAlgorithm && integrator.alg.algs[1] isa SOSRA2
ϱu = integrator.opts.internalnorm(k3 - k2, t)
Expand Down Expand Up @@ -255,7 +269,7 @@ end
mul!(E₂,gtmp,chi2)
@.. gtmp = beta11*g1 + beta12*g2 + beta13*g3
mul!(E₁,gtmp,W.dW)
u = uprev + dt*(α1*k1 + α2*k2 + α3*k3) + E₂ + E₁
@.. u = uprev + dt*(α1*k1 + α2*k2 + α3*k3) + E₂ + E₁
end

if integrator.alg isa StochasticCompositeAlgorithm && integrator.alg.algs[1] isa SOSRA2
Expand Down
41 changes: 41 additions & 0 deletions test/nondiagonal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,49 @@ prob = SDEProblem(f_morenoise,g_morenoise,ones(2),(0.0,1.0),
noise_rate_prototype=zeros(2,4))

sol =solve(prob,dt=1/2^(3),EM())
sol =solve(prob,dt=1/2^(3),SRA())
sol =solve(prob,dt=1/2^(3),ISSEM())
sol =solve(prob,dt=1/2^(3),ImplicitEM())
sol =solve(prob,dt=1/2^(3),EulerHeun())
sol =solve(prob,dt=1/2^(3),ImplicitEulerHeun())
sol =solve(prob,dt=1/2^(3),ISSEulerHeun())


f(du, u, p, t) = du .= u
function g(du, u, p, t)
du .= [-0.80 -0.3; -0.8 0.3]
end

u0 = ones(2)
dt = 1//2^(4)
tspan = (0., 1.)

prototype = zeros(2,2)

iip_prob = SDEProblem{true}(f, g, u0, tspan, noise_rate_prototype = prototype)
@test !(solve(iip_prob, EM(), dt = 0.1)[end] ones(2))
@test !(solve(iip_prob, SOSRA())[end] ones(2))
@test !(solve(iip_prob, SRA3())[end] ones(2))

# Out of place regression tests

f(u, p, t) = u
function g(u, p, t)
return [-0.80 -0.3; -0.8 0.3]
end

u0 = ones(2)
dt = 1//2^(4)
tspan = (0., 1.)

prototype = zeros(2,2)

oop_prob = SDEProblem{false}(f, g, u0, tspan, noise_rate_prototype = prototype)
oop_sol = solve(oop_prob, EM(), dt = dt)
oop_sol = solve(oop_prob, SOSRA())
sol =solve(oop_prob,dt=1/2^(3),EM())
sol =solve(oop_prob,dt=1/2^(3),ISSEM())
@test_broken sol =solve(oop_prob,dt=1/2^(3),ImplicitEM())
sol =solve(oop_prob,dt=1/2^(3),EulerHeun())
@test_broken sol =solve(oop_prob,dt=1/2^(3),ImplicitEulerHeun())
sol =solve(oop_prob,dt=1/2^(3),ISSEulerHeun())

0 comments on commit 07bd47f

Please sign in to comment.