Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RKMilCommute constant cache perform step #383

Merged
merged 14 commits into from
Jan 19, 2021
5 changes: 2 additions & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ steps:
# codecov: true
agents:
queue: "juliagpu"
cuda: "*"
fastcpu: "true"
env:
GROUP: 'WeakConvergence'
timeout_in_minutes: 120
timeout_in_minutes: 480
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/

Expand All @@ -29,7 +28,7 @@ steps:
fastcpu: "true"
env:
GROUP: 'WeakAdaptive'
timeout_in_minutes: 120
timeout_in_minutes: 240
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/

Expand Down
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ jobs:
- WeakConvergence2
- WeakConvergence3
- WeakConvergence4
- WeakConvergence5
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
14 changes: 8 additions & 6 deletions src/caches/basic_method_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ function alg_cache(alg::RKMil,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
RKMilCache(u,uprev,du1,du2,K,tmp,L)
end

struct RKMilCommuteConstantCache <: StochasticDiffEqConstantCache end
struct RKMilCommuteConstantCache{WikType} <: StochasticDiffEqConstantCache
WikJ::WikType
end
@cache struct RKMilCommuteCache{uType,rateType,rateNoiseType,WikType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
Expand All @@ -125,24 +127,24 @@ struct RKMilCommuteConstantCache <: StochasticDiffEqConstantCache end
gtmp::rateNoiseType
L::rateNoiseType
WikJ::WikType
Dg::WikType
mil_correction::rateType
Kj::uType
Dgj::rateNoiseType
tmp::uType
end

alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{false}}) = RKMilCommuteConstantCache()
function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{false}}) WikJ = WikJ = get_WikJ(ΔW,prob,alg)
RKMilCommuteConstantCache{typeof(WikJ)}(WikJ)
end

function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{true}})
du1 = zero(rate_prototype); du2 = zero(rate_prototype)
K = zero(rate_prototype); gtmp = zero(noise_rate_prototype);
L = zero(noise_rate_prototype); tmp = zero(rate_prototype)
WikJ = false .* vec(ΔW) .* vec(ΔW)'
Dg = false .* vec(ΔW) .* vec(ΔW)'
WikJ = get_WikJ(ΔW,prob,alg)
mil_correction = zero(rate_prototype)
Kj = zero(u); Dgj = zero(noise_rate_prototype)
RKMilCommuteCache(u,uprev,du1,du2,K,gtmp,L,WikJ,Dg,mil_correction,Kj,Dgj,tmp)
RKMilCommuteCache(u,uprev,du1,du2,K,gtmp,L,WikJ,mil_correction,Kj,Dgj,tmp)
end

struct RKMilGeneralConstantCache{WikType} <: StochasticDiffEqConstantCache
Expand Down
18 changes: 17 additions & 1 deletion src/iterated_integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,23 @@ function get_WikJ(ΔW,prob,alg)
elseif alg.ii_approx isa IICommutative
return WikJCommute_oop()
else
return KPWJ_oop #WikJGeneral_oop(ΔW)
return KPWJ_oop() #WikJGeneral_oop(ΔW)
end
end
end

function get_WikJ(ΔW,prob,alg::RKMilCommute)
if isinplace(prob)
if typeof(ΔW) <: Number || is_diagonal_noise(prob)
return WikJDiagonal_iip(ΔW)
else
return WikJCommute_iip(ΔW)
end
else
if typeof(ΔW) <: Number || is_diagonal_noise(prob)
return WikJDiagonal_oop()
else
return WikJCommute_oop()
end
end
end
Expand Down
107 changes: 93 additions & 14 deletions src/perform_step/low_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,39 +246,118 @@ end
end
end

@muladd function perform_step!(integrator,cache::RKMilCommuteConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,W,p = integrator
dW = W.dW; sqdt = integrator.sqdt
Wik = cache.WikJ

ggprime_norm = 0.0

WikJ = get_iterated_I(dt, dW, W.dZ, Wik)

mil_correction = zero(u)
if alg_interpretation(integrator.alg) == :Ito
if typeof(dW) <: Number || is_diagonal_noise(integrator.sol.prob)
WikJ = WikJ .- 1//2 .* dt
else
WikJ -= 1//2 .* UniformScaling(dt)
end
end

du1 = integrator.f(uprev,p,t)
L = integrator.g(uprev,p,t)

K = uprev + dt*du1

if is_diagonal_noise(integrator.sol.prob)
tmp = (alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
gtmp = integrator.g(tmp,p,t)
Dgj = (gtmp - L)/sqdt
ggprime_norm = integrator.opts.internalnorm(Dgj,t)
u = @.. K + L*dW + Dgj*WikJ
else
for j = 1:length(dW)
if typeof(dW) <: Number
Kj = K + sqdt*L
else
Kj = K + sqdt*@view(L[:,j])
end
gtmp = integrator.g(Kj,p,t)
Dgj = (gtmp - L)/sqdt
if integrator.opts.adaptive
ggprime_norm += integrator.opts.internalnorm(Dgj,t)
end
if typeof(dW) <: Number
tmp = Dgj*WikJ
else
tmp = Dgj*@view(WikJ[:,j])
end
mil_correction += tmp
end
tmp = L*dW
u = uprev + dt*du1 + tmp + mil_correction
end
if integrator.opts.adaptive
En = integrator.opts.internalnorm(dW,t)^3*ggprime_norm^2 / 6
du2 = integrator.f(K,p,t+dt)
tmp = integrator.opts.internalnorm(integrator.opts.delta * dt * (du2 - du1) / 2,t) + En

tmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(tmp,t)

end
integrator.u = u
end

@muladd function perform_step!(integrator,cache::RKMilCommuteCache,f=integrator.f)
@unpack du1,du2,K,gtmp,L = cache
@unpack t,dt,uprev,u,W,p = integrator
@unpack WikJ,mil_correction,Kj,Dgj,tmp = cache
dW = W.dW; sqdt = integrator.sqdt
f = integrator.f; g = integrator.g

ggprime_norm = 0.0

Wik = cache.WikJ

get_iterated_I!(dt, dW, W.dZ, Wik)
WikJ = Wik.WikJ

@.. mil_correction = zero(u)
if alg_interpretation(integrator.alg) == :Ito
WikJ .= 0.5 .* (vec(dW) .* vec(dW)' .- dt .* Eye{eltype(W.dW)}(length(W.dW)))
else
WikJ .= 0.5 .* vec(dW) .* vec(dW)'
if typeof(dW) <: Number || is_diagonal_noise(integrator.sol.prob)
@.. WikJ -= 1//2*dt
else
WikJ -= 1//2 .* UniformScaling(dt)
end
end

integrator.f(du1,uprev,p,t)
integrator.g(L,uprev,p,t)

@.. K = uprev + dt*du1
for j = 1:length(dW)
@.. Kj = K + sqdt*@view(L[:,j]) # This works too
#Kj .= uprev .+ sqdt*L[:,j]
g(gtmp,Kj,p,t)

if is_diagonal_noise(integrator.sol.prob)
tmp .= (alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
integrator.g(gtmp,tmp,p,t)
@.. Dgj = (gtmp - L)/sqdt
if integrator.opts.adaptive
ggprime_norm += integrator.opts.internalnorm(Dgj,t)
ggprime_norm = integrator.opts.internalnorm(Dgj,t)
@.. u = K + L*dW + Dgj*WikJ
else
for j = 1:length(dW)
@.. Kj = K + sqdt*@view(L[:,j]) # This works too
#Kj .= uprev .+ sqdt*L[:,j]
integrator.g(gtmp,Kj,p,t)
@.. Dgj = (gtmp - L)/sqdt
if integrator.opts.adaptive
ggprime_norm += integrator.opts.internalnorm(Dgj,t)
end
mul!(tmp,Dgj,@view(WikJ[:,j]))
mil_correction .+= tmp
end
mul!(tmp,Dgj,@view(WikJ[:,j]))
mil_correction .+= tmp
mul!(tmp,L,dW)
@.. u .= uprev + dt*du1 + tmp + mil_correction
end
mul!(tmp,L,dW)
@.. u .= uprev + dt*du1 + tmp + mil_correction

if integrator.opts.adaptive
En = integrator.opts.internalnorm(W.dW,t)^3*ggprime_norm^2 / 6
Expand Down
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ function DiffEqBase.__init(
W = WienerProcess(t,rand_prototype,rand_prototype2,
save_everystep=save_noise,
rng = Xorshifts.Xoroshiro128Plus(_seed))
elseif alg===RKMilGeneral()
elseif typeof(alg) <: RKMilGeneral
m = length(rand_prototype)
if typeof(rand_prototype) <: Number || alg.p == nothing
rand_prototype2 = nothing
Expand Down
35 changes: 35 additions & 0 deletions test/commutative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ function f_commute(du,u,p,t)
du .+= 1.01u
end

function f_commute_oop(u,p,t)
du = A*u
du += 1.01u
end

function f_commute_analytic(u0,p,t,W)
tmp = (A+1.01I-2*(B^2))*t + B*sum(W)
exp(tmp)*u0
Expand All @@ -30,6 +35,13 @@ function σ(du,u,p,t)
du[2,4] = σ_const*u[2]
end

function σ_oop(u,p,t)
σ_const*[
u[1] u[1] u[1] u[1]
u[2] u[2] u[2] u[2]
]
end

ff_commute = SDEFunction(f_commute,σ,analytic=f_commute_analytic)

prob = SDEProblem(ff_commute,σ,u0,(0.0,1.0),noise_rate_prototype=rand(2,4))
Expand All @@ -51,3 +63,26 @@ sim2 = test_convergence(dts,prob,RKMilGeneral(ii_approx=IICommutative()),traject

sim2 = test_convergence(dts,prob,RKMilGeneral(p=2),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2


ff_commute_oop = SDEFunction(f_commute_oop,σ_oop,analytic=f_commute_analytic)

proboop = SDEProblem(ff_commute_oop,σ_oop,u0,(0.0,1.0),noise_rate_prototype=rand(2,4))

sol = solve(proboop,RKMilCommute(),dt=1/2^(8))
sol = solve(proboop,RKMilGeneral(ii_approx=IICommutative()),dt=1/2^(8))
sol = solve(proboop,RKMilGeneral(p=2),dt=1/2^(10))
sol = solve(proboop,EM(),dt=1/2^(10))

dts = (1/2) .^ (10:-1:3) #14->7 good plot
sim2 = test_convergence(dts,proboop,EM(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:final] - 0.5) < 0.2

sim2 = test_convergence(dts,proboop,RKMilCommute(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2

sim2 = test_convergence(dts,proboop,RKMilGeneral(ii_approx=IICommutative()),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2

sim2 = test_convergence(dts,proboop,RKMilGeneral(p=2),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2
2 changes: 1 addition & 1 deletion test/gpu/sde_weak_brusselator_adaptive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ensembleprob = EnsembleProblem(prob, prob_func = prob_func)
#Performance check with nvvp
# CUDAnative.CUDAdrv.@profile
# check either on CPU with EnsembleCPUArray() or on GPU with EnsembleGPUArray()
sol = @time solve(ensembleprob,DRI1(),EnsembleCPUArray(),trajectories=numtraj)
@test_nowarn sol = @time solve(ensembleprob,DRI1(),EnsembleCPUArray(),trajectories=numtraj)
#sol = @time solve(ensembleprob,DRI1(),EnsembleGPUArray(),trajectories=numtraj)


Expand Down
19 changes: 10 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,31 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR")
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence1")
@time @safetestset "OOP Weak Convergence Tests" begin include("weak_convergence/oop_weak.jl") end
@time @safetestset "Additive Weak Convergence Tests" begin include("weak_convergence/additive_weak.jl") end
@time @safetestset "IIP Weak Convergence Tests" begin include("weak_convergence/iip_weak.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence2")
@time @safetestset "Multidimensional IIP Weak Convergence Tests" begin include("weak_convergence/multidim_iip_weak.jl") end
@time @safetestset "Platen's PL1WM weak second order" begin include("weak_convergence/PL1WM.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence3")
if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence2")
@time @safetestset "Roessler weak SRK Tests" begin include("weak_convergence/srk_weak_final.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence4")
if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence3")
@time @safetestset "Roessler weak SRK (non-diagonal) Tests" begin include("weak_convergence/srk_weak_final_non_diagonal.jl") end
@time @safetestset "Weak Stratonovich Tests" begin include("weak_convergence/weak_strat.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence5")
if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence4")
@time @safetestset "Weak Stratonovich (non-diagonal) Tests" begin include("weak_convergence/weak_strat_non_diagonal.jl") end
@time @safetestset "SIE SME weak Tests" begin include("weak_convergence/SIE_SME.jl") end
end

if !is_APPVEYOR && GROUP == "WeakConvergence"
#activate_gpu_env()
@time @safetestset "OOP Weak Convergence Tests" begin include("weak_convergence/oop_weak.jl") end
@time @safetestset "Additive Weak Convergence Tests" begin include("weak_convergence/additive_weak.jl") end
@time @safetestset "IIP Weak Convergence Tests" begin include("weak_convergence/iip_weak.jl") end
end

if !is_APPVEYOR && GROUP == "WeakAdaptive"
activate_gpu_env()
@time @safetestset "Weak adaptive step size Brusselator " begin include("gpu/sde_weak_brusselator_adaptive.jl") end
Expand Down
6 changes: 6 additions & 0 deletions test/sde/sde_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ sim = test_convergence(dts,prob,LambaEM(),trajectories=Int(1e2))
@test abs(sim.𝒪est[:l2]-.5) < 0.2
sim2 = test_convergence(dts,prob,RKMil(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilCommute(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilGeneral(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
print(".")
Expand Down Expand Up @@ -101,6 +103,8 @@ sim = test_convergence(dts,prob,ImplicitRKMil(),trajectories=Int(1e2))
@test abs(sim.𝒪est[:l2]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMil(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.22
sim2 = test_convergence(dts,prob,RKMilCommute(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.22
sim2 = test_convergence(dts,prob,RKMilGeneral(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.22
print(".")
Expand Down Expand Up @@ -169,6 +173,8 @@ sim = test_convergence(dts,prob,ImplicitRKMil(),trajectories=Int(1e2))
@test abs(sim.𝒪est[:l2]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMil(),trajectories=Int(1e1))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilCommute(),trajectories=Int(1e1))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilGeneral(),trajectories=Int(1e1))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
print(".")
Expand Down
5 changes: 5 additions & 0 deletions test/sde/sde_linear_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ prob = prob_sde_linear
println("Solve and Plot")
sol = solve(prob,EM(),dt=1//2^(4))
sol = solve(prob,RKMil(),dt=1//2^(4))
sol = solve(prob,RKMilCommute(),dt=1//2^(4))
sol = solve(prob,RKMilGeneral(),dt=1//2^(4))
sol = solve(prob,SRI(),dt=1//2^(4))
sol = solve(prob,SRIW1(),dt=1//2^(4))
Expand All @@ -22,6 +23,8 @@ sim2 = test_convergence(dts,prob,RKMil(),trajectories=trajectories)

sim21 = test_convergence(dts,prob,RKMilGeneral(),trajectories=trajectories)

sim22 = test_convergence(dts,prob,RKMilCommute(),trajectories=trajectories)

sim3 = test_convergence(dts,prob,SRI(),trajectories=trajectories)

#TEST_PLOT && plot(plot(sim),plot(sim2),plot(sim3),layout=@layout([a b c]),size=(1200,600))
Expand All @@ -30,6 +33,8 @@ sim3 = test_convergence(dts,prob,SRI(),trajectories=trajectories)

@test abs(sim.𝒪est[:l2]-.5) + abs(sim21.𝒪est[:l∞]-1) + abs(sim3.𝒪est[:final]-1.5)<.441 #High tolerance since low dts for testing!

@test abs(sim22.𝒪est[:l∞]-1)<.3

# test reinit
integrator = init(prob,EM(),dt=1//2^(4))
solve!(integrator)
Expand Down
Loading