diff --git a/src/perform_step/srk_weak.jl b/src/perform_step/srk_weak.jl index aa6ab16ff..746b9e8d8 100644 --- a/src/perform_step/srk_weak.jl +++ b/src/perform_step/srk_weak.jl @@ -55,7 +55,7 @@ if typeof(W.dW) <: Number H12 = uprev + a121*k1*dt + b121*g1*integrator.sqdt elseif is_diagonal_noise(integrator.sol.prob) - H12 = Vector{typeof(uprev)}[uprev .+ a121*k1*dt .+ b121*g1[k]*integrator.sqdt for k=1:m] + H12 = [uprev .+ a121*k1*dt .+ b121*g1[k]*integrator.sqdt for k=1:m] else H12 = Vector{typeof(uprev)}[uprev .+ a121*k1*dt .+ b121*g1[:,k]*integrator.sqdt for k=1:m] end @@ -64,7 +64,7 @@ if typeof(W.dW) <: Number H13 = uprev + a131*k1*dt + b131*g1*integrator.sqdt elseif is_diagonal_noise(integrator.sol.prob) - H13 = Vector{typeof(uprev)}[uprev .+ a131*k1*dt .+ b131*g1[k]*integrator.sqdt for k=1:m] + H13 = [uprev .+ a131*k1*dt .+ b131*g1[k]*integrator.sqdt for k=1:m] else H13 = Vector{typeof(uprev)}[uprev .+ a131*k1*dt .+ b131*g1[:,k]*integrator.sqdt for k=1:m] end @@ -88,7 +88,11 @@ continue end if is_diagonal_noise(integrator.sol.prob) - @.. H22[k] += (b221*g1[l]+b222*g2[l][l]+b223*g3[l][l])*Ihat2[k,l]/integrator.sqdt + if typeof(W.dW) <: SArray + H22[k] = @.. H22[k] + (b221*g1[l]+b222*g2[l][l]+b223*g3[l][l])*Ihat2[k,l]/integrator.sqdt + else + @.. H22[k] += (b221*g1[l]+b222*g2[l][l]+b223*g3[l][l])*Ihat2[k,l]/integrator.sqdt + end else H22[k] += (b221*g1[:,l]+b222*g2[l][:,l]+b223*g3[l][:,l])*Ihat2[k,l]/integrator.sqdt end @@ -107,7 +111,11 @@ continue end if is_diagonal_noise(integrator.sol.prob) - @.. H23[k] += (b231*g1[l]+b232*g2[l][l]+b233*g3[l][l])*Ihat2[k,l]/integrator.sqdt + if typeof(W.dW) <: SArray + H23[k] = @.. H23[k]+ (b231*g1[l]+b232*g2[l][l]+b233*g3[l][l])*Ihat2[k,l]/integrator.sqdt + else + @.. H23[k] += (b231*g1[l]+b232*g2[l][l]+b233*g3[l][l])*Ihat2[k,l]/integrator.sqdt + end else H23[k] += (b231*g1[:,l]+b232*g2[l][:,l]+b233*g3[l][:,l])*Ihat2[k,l]/integrator.sqdt end @@ -127,13 +135,26 @@ else if is_diagonal_noise(integrator.sol.prob) u += g1.*_dW*beta11 - for k=1:m - u[k] += g2[k][k]*_dW[k]*beta12+g3[k][k]*_dW[k]*beta13 - u[k] += g2[k][k]*chi1[k]*beta22/integrator.sqdt + g3[k][k]*chi1[k]*beta23/integrator.sqdt - tmpg = integrator.g(H22[k],p,t) - u[k] = u[k] + g1[k]*_dW[k]*beta31 + tmpg[k]*(_dW[k]*beta32 + integrator.sqdt*beta42) - tmpg = integrator.g(H23[k],p,t) - u[k] = u[k] + tmpg[k]*(_dW[k]*beta33 + integrator.sqdt*beta43) + if typeof(W.dW) <: SArray + for k=1:m + tmp = u[k] + tmp += g2[k][k]*_dW[k]*beta12+g3[k][k]*_dW[k]*beta13 + tmp += g2[k][k]*chi1[k]*beta22/integrator.sqdt + g3[k][k]*chi1[k]*beta23/integrator.sqdt + tmpg = integrator.g(H22[k],p,t) + tmp = tmp + g1[k]*_dW[k]*beta31 + tmpg[k]*(_dW[k]*beta32 + integrator.sqdt*beta42) + tmpg = integrator.g(H23[k],p,t) + tmp = tmp + tmpg[k]*(_dW[k]*beta33 + integrator.sqdt*beta43) + u = setindex(u, tmp, k) + end + else + for k=1:m + u[k] += g2[k][k]*_dW[k]*beta12+g3[k][k]*_dW[k]*beta13 + u[k] += g2[k][k]*chi1[k]*beta22/integrator.sqdt + g3[k][k]*chi1[k]*beta23/integrator.sqdt + tmpg = integrator.g(H22[k],p,t) + u[k] = u[k] + g1[k]*_dW[k]*beta31 + tmpg[k]*(_dW[k]*beta32 + integrator.sqdt*beta42) + tmpg = integrator.g(H23[k],p,t) + u[k] = u[k] + tmpg[k]*(_dW[k]*beta33 + integrator.sqdt*beta43) + end end else # non-diag noise diff --git a/test/static_array_tests.jl b/test/static_array_tests.jl index 9bc56d2ea..8342b6d61 100644 --- a/test/static_array_tests.jl +++ b/test/static_array_tests.jl @@ -1,5 +1,5 @@ using StaticArrays, StochasticDiffEq - +using Test f(du,u,p,t) = (du .= u) u0 = zeros(MVector{2,Float64}, 2) .+ 1 @@ -14,14 +14,15 @@ ode = SDEProblem(f, f, u0, (0.,1.)) @test_broken sol = solve(ode, EM(), dt=1.e-2) @test_broken sol = solve(ode, SRIW1(), dt=1.e-2) - u0 = ones(MVector{2,Float64}) ode = SDEProblem(f, f, u0, (0.,1.)) sol = solve(ode, EM(), dt=1.e-2) sol = solve(ode, SRIW1()) +sol = solve(ode, DRI1(), dt=1.e-2) u0 = ones(SVector{2,Float64}) f(u,p,t) = u -prob = SDEProblem(f, f, u0, (0.,1.)) -sol = solve(ode, EM(), dt=1.e-2) -sol = solve(ode, SRIW1()) +prob = SDEProblem{false}(f, f, u0, (0.,1.)) +sol = solve(prob, EM(), dt=1.e-2) +sol = solve(prob, SRIW1()) +sol = solve(prob, DRI1(), dt=1.e-2)