diff --git a/src/perform_step/sra.jl b/src/perform_step/sra.jl index 6f12c86c..bb0a8d42 100644 --- a/src/perform_step/sra.jl +++ b/src/perform_step/sra.jl @@ -178,20 +178,34 @@ end g1 = integrator.g(uprev,p,t+c11*dt) k1 = integrator.f(uprev,p,t) - H01 = uprev + dt*a21*k1 + b21*chi2.*g1 + if is_diagonal_noise(integrator.sol.prob) + H01 = uprev + dt*a21*k1 + b21*chi2.*g1 + else + H01 = uprev + dt*a21*k1 + b21*g1*chi2 + end g2 = integrator.g(H01,p,t+c12*dt) k2 = integrator.f(H01,p,t+c02*dt) - H02 = uprev + dt*(a31*k1 + a32*k2) + chi2.*(b31*g1 + b32*g2) + if is_diagonal_noise(integrator.sol.prob) + H02 = uprev + dt*(a31*k1 + a32*k2) + chi2.*(b31*g1 + b32*g2) + else + H02 = uprev + dt*(a31*k1 + a32*k2) + (b31*g1 + b32*g2)*chi2 + end + g3 = integrator.g(H02,p,t+c13*dt) k3 = integrator.f(H02,p,t+c03*dt) E₁ = dt*(α1*k1 + α2*k2 + α3*k3) - E₂ = chi2.*(beta21*g1 + beta22*g2 + beta23*g3) - u = uprev + E₁ + E₂ + W.dW.*(beta11*g1 + beta12*g2 + beta13*g3) + if is_diagonal_noise(integrator.sol.prob) + E₂ = chi2.*(beta21*g1 + beta22*g2 + beta23*g3) + u = uprev + E₁ + E₂ + W.dW.*(beta11*g1 + beta12*g2 + beta13*g3) + else + E₂ = (beta21*g1 + beta22*g2 + beta23*g3)*chi2 + u = uprev + E₁ + E₂ + (beta11*g1 + beta12*g2 + beta13*g3)*W.dW + end if integrator.alg isa StochasticCompositeAlgorithm && integrator.alg.algs[1] isa SOSRA2 ϱu = integrator.opts.internalnorm(k3 - k2, t) @@ -255,7 +269,7 @@ end mul!(E₂,gtmp,chi2) @.. gtmp = beta11*g1 + beta12*g2 + beta13*g3 mul!(E₁,gtmp,W.dW) - u = uprev + dt*(α1*k1 + α2*k2 + α3*k3) + E₂ + E₁ + @.. u = uprev + dt*(α1*k1 + α2*k2 + α3*k3) + E₂ + E₁ end if integrator.alg isa StochasticCompositeAlgorithm && integrator.alg.algs[1] isa SOSRA2 diff --git a/test/nondiagonal_tests.jl b/test/nondiagonal_tests.jl index 7bfa6374..7f2da112 100644 --- a/test/nondiagonal_tests.jl +++ b/test/nondiagonal_tests.jl @@ -108,8 +108,48 @@ prob = SDEProblem(f_morenoise,g_morenoise,ones(2),(0.0,1.0), noise_rate_prototype=zeros(2,4)) sol =solve(prob,dt=1/2^(3),EM()) +sol =solve(prob,dt=1/2^(3),SRA()) sol =solve(prob,dt=1/2^(3),ISSEM()) sol =solve(prob,dt=1/2^(3),ImplicitEM()) sol =solve(prob,dt=1/2^(3),EulerHeun()) sol =solve(prob,dt=1/2^(3),ImplicitEulerHeun()) sol =solve(prob,dt=1/2^(3),ISSEulerHeun()) + + +f(du, u, p, t) = du .= u +function g(du, u, p, t) + du .= [-0.80 -0.3; -0.8 0.3] +end + +u0 = ones(2) +dt = 1//2^(4) +tspan = (0., 1.) + +prototype = zeros(2,2) + +iip_prob = SDEProblem{true}(f, g, u0, tspan, noise_rate_prototype = prototype) +@test !(solve(iip_prob, EM(), dt = 0.1)[end] ≈ ones(2)) +@test !(solve(iip_prob, SOSRA())[end] ≈ ones(2)) + +# Out of place regression tests + +f(u, p, t) = u +function g(u, p, t) + return [-0.80 -0.3; -0.8 0.3] +end + +u0 = ones(2) +dt = 1//2^(4) +tspan = (0., 1.) + +prototype = zeros(2,2) + +oop_prob = SDEProblem{false}(f, g, u0, tspan, noise_rate_prototype = prototype) +oop_sol = solve(oop_prob, EM(), dt = dt) +oop_sol = solve(oop_prob, SOSRA()) +sol =solve(oop_prob,dt=1/2^(3),EM()) +sol =solve(oop_prob,dt=1/2^(3),ISSEM()) +@test_broken sol =solve(oop_prob,dt=1/2^(3),ImplicitEM()) +sol =solve(oop_prob,dt=1/2^(3),EulerHeun()) +@test_broken sol =solve(oop_prob,dt=1/2^(3),ImplicitEulerHeun()) +sol =solve(oop_prob,dt=1/2^(3),ISSEulerHeun()) \ No newline at end of file