Skip to content

Commit

Permalink
Merge pull request #383 from frankschae/RKMilCommute
Browse files Browse the repository at this point in the history
RKMilCommute constant cache perform step
  • Loading branch information
ChrisRackauckas authored Jan 19, 2021
2 parents 1c13a69 + 82a42a3 commit 327bf80
Show file tree
Hide file tree
Showing 15 changed files with 228 additions and 57 deletions.
5 changes: 2 additions & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ steps:
# codecov: true
agents:
queue: "juliagpu"
cuda: "*"
fastcpu: "true"
env:
GROUP: 'WeakConvergence'
timeout_in_minutes: 120
timeout_in_minutes: 480
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/

Expand All @@ -29,7 +28,7 @@ steps:
fastcpu: "true"
env:
GROUP: 'WeakAdaptive'
timeout_in_minutes: 120
timeout_in_minutes: 240
# Don't run Buildkite if the commit message includes the text [skip tests]
if: build.message !~ /\[skip tests\]/

Expand Down
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ jobs:
- WeakConvergence2
- WeakConvergence3
- WeakConvergence4
- WeakConvergence5
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
14 changes: 8 additions & 6 deletions src/caches/basic_method_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ function alg_cache(alg::RKMil,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
RKMilCache(u,uprev,du1,du2,K,tmp,L)
end

struct RKMilCommuteConstantCache <: StochasticDiffEqConstantCache end
struct RKMilCommuteConstantCache{WikType} <: StochasticDiffEqConstantCache
WikJ::WikType
end
@cache struct RKMilCommuteCache{uType,rateType,rateNoiseType,WikType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
Expand All @@ -125,24 +127,24 @@ struct RKMilCommuteConstantCache <: StochasticDiffEqConstantCache end
gtmp::rateNoiseType
L::rateNoiseType
WikJ::WikType
Dg::WikType
mil_correction::rateType
Kj::uType
Dgj::rateNoiseType
tmp::uType
end

alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{false}}) = RKMilCommuteConstantCache()
function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{false}}) WikJ = WikJ = get_WikJ(ΔW,prob,alg)
RKMilCommuteConstantCache{typeof(WikJ)}(WikJ)
end

function alg_cache(alg::RKMilCommute,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,::Type{Val{true}})
du1 = zero(rate_prototype); du2 = zero(rate_prototype)
K = zero(rate_prototype); gtmp = zero(noise_rate_prototype);
L = zero(noise_rate_prototype); tmp = zero(rate_prototype)
WikJ = false .* vec(ΔW) .* vec(ΔW)'
Dg = false .* vec(ΔW) .* vec(ΔW)'
WikJ = get_WikJ(ΔW,prob,alg)
mil_correction = zero(rate_prototype)
Kj = zero(u); Dgj = zero(noise_rate_prototype)
RKMilCommuteCache(u,uprev,du1,du2,K,gtmp,L,WikJ,Dg,mil_correction,Kj,Dgj,tmp)
RKMilCommuteCache(u,uprev,du1,du2,K,gtmp,L,WikJ,mil_correction,Kj,Dgj,tmp)
end

struct RKMilGeneralConstantCache{WikType} <: StochasticDiffEqConstantCache
Expand Down
18 changes: 17 additions & 1 deletion src/iterated_integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,23 @@ function get_WikJ(ΔW,prob,alg)
elseif alg.ii_approx isa IICommutative
return WikJCommute_oop()
else
return KPWJ_oop #WikJGeneral_oop(ΔW)
return KPWJ_oop() #WikJGeneral_oop(ΔW)
end
end
end

function get_WikJ(ΔW,prob,alg::RKMilCommute)
if isinplace(prob)
if typeof(ΔW) <: Number || is_diagonal_noise(prob)
return WikJDiagonal_iip(ΔW)
else
return WikJCommute_iip(ΔW)
end
else
if typeof(ΔW) <: Number || is_diagonal_noise(prob)
return WikJDiagonal_oop()
else
return WikJCommute_oop()
end
end
end
Expand Down
107 changes: 93 additions & 14 deletions src/perform_step/low_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,39 +246,118 @@ end
end
end

@muladd function perform_step!(integrator,cache::RKMilCommuteConstantCache,f=integrator.f)
@unpack t,dt,uprev,u,W,p = integrator
dW = W.dW; sqdt = integrator.sqdt
Wik = cache.WikJ

ggprime_norm = 0.0

WikJ = get_iterated_I(dt, dW, W.dZ, Wik)

mil_correction = zero(u)
if alg_interpretation(integrator.alg) == :Ito
if typeof(dW) <: Number || is_diagonal_noise(integrator.sol.prob)
WikJ = WikJ .- 1//2 .* dt
else
WikJ -= 1//2 .* UniformScaling(dt)
end
end

du1 = integrator.f(uprev,p,t)
L = integrator.g(uprev,p,t)

K = uprev + dt*du1

if is_diagonal_noise(integrator.sol.prob)
tmp = (alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
gtmp = integrator.g(tmp,p,t)
Dgj = (gtmp - L)/sqdt
ggprime_norm = integrator.opts.internalnorm(Dgj,t)
u = @.. K + L*dW + Dgj*WikJ
else
for j = 1:length(dW)
if typeof(dW) <: Number
Kj = K + sqdt*L
else
Kj = K + sqdt*@view(L[:,j])
end
gtmp = integrator.g(Kj,p,t)
Dgj = (gtmp - L)/sqdt
if integrator.opts.adaptive
ggprime_norm += integrator.opts.internalnorm(Dgj,t)
end
if typeof(dW) <: Number
tmp = Dgj*WikJ
else
tmp = Dgj*@view(WikJ[:,j])
end
mil_correction += tmp
end
tmp = L*dW
u = uprev + dt*du1 + tmp + mil_correction
end
if integrator.opts.adaptive
En = integrator.opts.internalnorm(dW,t)^3*ggprime_norm^2 / 6
du2 = integrator.f(K,p,t+dt)
tmp = integrator.opts.internalnorm(integrator.opts.delta * dt * (du2 - du1) / 2,t) + En

tmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(tmp,t)

end
integrator.u = u
end

@muladd function perform_step!(integrator,cache::RKMilCommuteCache,f=integrator.f)
@unpack du1,du2,K,gtmp,L = cache
@unpack t,dt,uprev,u,W,p = integrator
@unpack WikJ,mil_correction,Kj,Dgj,tmp = cache
dW = W.dW; sqdt = integrator.sqdt
f = integrator.f; g = integrator.g

ggprime_norm = 0.0

Wik = cache.WikJ

get_iterated_I!(dt, dW, W.dZ, Wik)
WikJ = Wik.WikJ

@.. mil_correction = zero(u)
if alg_interpretation(integrator.alg) == :Ito
WikJ .= 0.5 .* (vec(dW) .* vec(dW)' .- dt .* Eye{eltype(W.dW)}(length(W.dW)))
else
WikJ .= 0.5 .* vec(dW) .* vec(dW)'
if typeof(dW) <: Number || is_diagonal_noise(integrator.sol.prob)
@.. WikJ -= 1//2*dt
else
WikJ -= 1//2 .* UniformScaling(dt)
end
end

integrator.f(du1,uprev,p,t)
integrator.g(L,uprev,p,t)

@.. K = uprev + dt*du1
for j = 1:length(dW)
@.. Kj = K + sqdt*@view(L[:,j]) # This works too
#Kj .= uprev .+ sqdt*L[:,j]
g(gtmp,Kj,p,t)

if is_diagonal_noise(integrator.sol.prob)
tmp .= (alg_interpretation(integrator.alg) == :Ito ? K : uprev) .+ integrator.sqdt .* L
integrator.g(gtmp,tmp,p,t)
@.. Dgj = (gtmp - L)/sqdt
if integrator.opts.adaptive
ggprime_norm += integrator.opts.internalnorm(Dgj,t)
ggprime_norm = integrator.opts.internalnorm(Dgj,t)
@.. u = K + L*dW + Dgj*WikJ
else
for j = 1:length(dW)
@.. Kj = K + sqdt*@view(L[:,j]) # This works too
#Kj .= uprev .+ sqdt*L[:,j]
integrator.g(gtmp,Kj,p,t)
@.. Dgj = (gtmp - L)/sqdt
if integrator.opts.adaptive
ggprime_norm += integrator.opts.internalnorm(Dgj,t)
end
mul!(tmp,Dgj,@view(WikJ[:,j]))
mil_correction .+= tmp
end
mul!(tmp,Dgj,@view(WikJ[:,j]))
mil_correction .+= tmp
mul!(tmp,L,dW)
@.. u .= uprev + dt*du1 + tmp + mil_correction
end
mul!(tmp,L,dW)
@.. u .= uprev + dt*du1 + tmp + mil_correction

if integrator.opts.adaptive
En = integrator.opts.internalnorm(W.dW,t)^3*ggprime_norm^2 / 6
Expand Down
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ function DiffEqBase.__init(
W = WienerProcess(t,rand_prototype,rand_prototype2,
save_everystep=save_noise,
rng = Xorshifts.Xoroshiro128Plus(_seed))
elseif alg===RKMilGeneral()
elseif typeof(alg) <: RKMilGeneral
m = length(rand_prototype)
if typeof(rand_prototype) <: Number || alg.p == nothing
rand_prototype2 = nothing
Expand Down
35 changes: 35 additions & 0 deletions test/commutative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ function f_commute(du,u,p,t)
du .+= 1.01u
end

function f_commute_oop(u,p,t)
du = A*u
du += 1.01u
end

function f_commute_analytic(u0,p,t,W)
tmp = (A+1.01I-2*(B^2))*t + B*sum(W)
exp(tmp)*u0
Expand All @@ -30,6 +35,13 @@ function σ(du,u,p,t)
du[2,4] = σ_const*u[2]
end

function σ_oop(u,p,t)
σ_const*[
u[1] u[1] u[1] u[1]
u[2] u[2] u[2] u[2]
]
end

ff_commute = SDEFunction(f_commute,σ,analytic=f_commute_analytic)

prob = SDEProblem(ff_commute,σ,u0,(0.0,1.0),noise_rate_prototype=rand(2,4))
Expand All @@ -51,3 +63,26 @@ sim2 = test_convergence(dts,prob,RKMilGeneral(ii_approx=IICommutative()),traject

sim2 = test_convergence(dts,prob,RKMilGeneral(p=2),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2


ff_commute_oop = SDEFunction(f_commute_oop,σ_oop,analytic=f_commute_analytic)

proboop = SDEProblem(ff_commute_oop,σ_oop,u0,(0.0,1.0),noise_rate_prototype=rand(2,4))

sol = solve(proboop,RKMilCommute(),dt=1/2^(8))
sol = solve(proboop,RKMilGeneral(ii_approx=IICommutative()),dt=1/2^(8))
sol = solve(proboop,RKMilGeneral(p=2),dt=1/2^(10))
sol = solve(proboop,EM(),dt=1/2^(10))

dts = (1/2) .^ (10:-1:3) #14->7 good plot
sim2 = test_convergence(dts,proboop,EM(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:final] - 0.5) < 0.2

sim2 = test_convergence(dts,proboop,RKMilCommute(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2

sim2 = test_convergence(dts,proboop,RKMilGeneral(ii_approx=IICommutative()),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2

sim2 = test_convergence(dts,proboop,RKMilGeneral(p=2),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:final] - 1.0) < 0.2
2 changes: 1 addition & 1 deletion test/gpu/sde_weak_brusselator_adaptive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ensembleprob = EnsembleProblem(prob, prob_func = prob_func)
#Performance check with nvvp
# CUDAnative.CUDAdrv.@profile
# check either on CPU with EnsembleCPUArray() or on GPU with EnsembleGPUArray()
sol = @time solve(ensembleprob,DRI1(),EnsembleCPUArray(),trajectories=numtraj)
@test_nowarn sol = @time solve(ensembleprob,DRI1(),EnsembleCPUArray(),trajectories=numtraj)
#sol = @time solve(ensembleprob,DRI1(),EnsembleGPUArray(),trajectories=numtraj)


Expand Down
19 changes: 10 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,31 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV,"APPVEYOR")
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence1")
@time @safetestset "OOP Weak Convergence Tests" begin include("weak_convergence/oop_weak.jl") end
@time @safetestset "Additive Weak Convergence Tests" begin include("weak_convergence/additive_weak.jl") end
@time @safetestset "IIP Weak Convergence Tests" begin include("weak_convergence/iip_weak.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence2")
@time @safetestset "Multidimensional IIP Weak Convergence Tests" begin include("weak_convergence/multidim_iip_weak.jl") end
@time @safetestset "Platen's PL1WM weak second order" begin include("weak_convergence/PL1WM.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence3")
if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence2")
@time @safetestset "Roessler weak SRK Tests" begin include("weak_convergence/srk_weak_final.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence4")
if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence3")
@time @safetestset "Roessler weak SRK (non-diagonal) Tests" begin include("weak_convergence/srk_weak_final_non_diagonal.jl") end
@time @safetestset "Weak Stratonovich Tests" begin include("weak_convergence/weak_strat.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence5")
if !is_APPVEYOR && (GROUP == "All" || GROUP == "WeakConvergence4")
@time @safetestset "Weak Stratonovich (non-diagonal) Tests" begin include("weak_convergence/weak_strat_non_diagonal.jl") end
@time @safetestset "SIE SME weak Tests" begin include("weak_convergence/SIE_SME.jl") end
end

if !is_APPVEYOR && GROUP == "WeakConvergence"
#activate_gpu_env()
@time @safetestset "OOP Weak Convergence Tests" begin include("weak_convergence/oop_weak.jl") end
@time @safetestset "Additive Weak Convergence Tests" begin include("weak_convergence/additive_weak.jl") end
@time @safetestset "IIP Weak Convergence Tests" begin include("weak_convergence/iip_weak.jl") end
end

if !is_APPVEYOR && GROUP == "WeakAdaptive"
activate_gpu_env()
@time @safetestset "Weak adaptive step size Brusselator " begin include("gpu/sde_weak_brusselator_adaptive.jl") end
Expand Down
6 changes: 6 additions & 0 deletions test/sde/sde_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ sim = test_convergence(dts,prob,LambaEM(),trajectories=Int(1e2))
@test abs(sim.𝒪est[:l2]-.5) < 0.2
sim2 = test_convergence(dts,prob,RKMil(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilCommute(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilGeneral(),trajectories=Int(2e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
print(".")
Expand Down Expand Up @@ -101,6 +103,8 @@ sim = test_convergence(dts,prob,ImplicitRKMil(),trajectories=Int(1e2))
@test abs(sim.𝒪est[:l2]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMil(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.22
sim2 = test_convergence(dts,prob,RKMilCommute(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.22
sim2 = test_convergence(dts,prob,RKMilGeneral(),trajectories=Int(1e2))
@test abs(sim2.𝒪est[:l∞]-1) < 0.22
print(".")
Expand Down Expand Up @@ -169,6 +173,8 @@ sim = test_convergence(dts,prob,ImplicitRKMil(),trajectories=Int(1e2))
@test abs(sim.𝒪est[:l2]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMil(),trajectories=Int(1e1))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilCommute(),trajectories=Int(1e1))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
sim2 = test_convergence(dts,prob,RKMilGeneral(),trajectories=Int(1e1))
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
print(".")
Expand Down
5 changes: 5 additions & 0 deletions test/sde/sde_linear_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ prob = prob_sde_linear
println("Solve and Plot")
sol = solve(prob,EM(),dt=1//2^(4))
sol = solve(prob,RKMil(),dt=1//2^(4))
sol = solve(prob,RKMilCommute(),dt=1//2^(4))
sol = solve(prob,RKMilGeneral(),dt=1//2^(4))
sol = solve(prob,SRI(),dt=1//2^(4))
sol = solve(prob,SRIW1(),dt=1//2^(4))
Expand All @@ -22,6 +23,8 @@ sim2 = test_convergence(dts,prob,RKMil(),trajectories=trajectories)

sim21 = test_convergence(dts,prob,RKMilGeneral(),trajectories=trajectories)

sim22 = test_convergence(dts,prob,RKMilCommute(),trajectories=trajectories)

sim3 = test_convergence(dts,prob,SRI(),trajectories=trajectories)

#TEST_PLOT && plot(plot(sim),plot(sim2),plot(sim3),layout=@layout([a b c]),size=(1200,600))
Expand All @@ -30,6 +33,8 @@ sim3 = test_convergence(dts,prob,SRI(),trajectories=trajectories)

@test abs(sim.𝒪est[:l2]-.5) + abs(sim21.𝒪est[:l∞]-1) + abs(sim3.𝒪est[:final]-1.5)<.441 #High tolerance since low dts for testing!

@test abs(sim22.𝒪est[:l∞]-1)<.3

# test reinit
integrator = init(prob,EM(),dt=1//2^(4))
solve!(integrator)
Expand Down
Loading

0 comments on commit 327bf80

Please sign in to comment.