Skip to content

Commit

Permalink
add test data
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Apr 18, 2024
1 parent 523b16b commit 4f55e58
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 10 deletions.
7 changes: 7 additions & 0 deletions src/caches/rkn_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -706,3 +706,10 @@ function alg_cache(alg::RKN4, u, rate_prototype, ::Type{uEltypeNoUnits},
end

struct RKN4ConstantCache <: OrdinaryDiffEqConstantCache end

function alg_cache(alg::RKN4, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
RKN4ConstantCache()
end
27 changes: 22 additions & 5 deletions src/perform_step/rkn_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1851,16 +1851,13 @@ end
#perform operations to find k values
k₁ = integrator.fsalfirst.x[1]
ku = uprev + halfdt * duprev + eightdtsq * k₁
print(ku)
kdu = duprev + halfdt * k₁

k₂ = f.f1(kdu, ku, p, ttmp)
ku = uprev + dt * duprev + halfdtsq * k₂
kdu = duprev + dt * k₂

k₃ = f.f1(kdu, ku, p, t + dt)
ku = uprev + dt * duprev + eightdtsq * k₃
kdu = duprev + dt * k₃

#perform final calculations to determine new y and y'.
u = uprev + sixthdtsq* (1*k₁ + 2*k₂ + 0*k₃) + dt * duprev
Expand All @@ -1872,6 +1869,17 @@ end
integrator.stats.nf2 += 1
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
#=
if integrator.opts.adaptive
uhat = dtsq * (1/6 * k₁ + 1/3 * k₂ + 0 * k₃)
duhat = dt * (1/6 * k₁ + 1/3 * k₂ + 0 * k₃)
utilde = ArrayPartition((duhat, uhat))
atmp = calculate_residuals(utilde, integrator.uprev, integrator.u,
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
=#
end

@muladd function perform_step!(integrator, cache::RKN4Cache, repeat_step = false)
Expand Down Expand Up @@ -1900,8 +1908,6 @@ end
@.. broadcast=false kdu = duprev + dt * k₂

f.f1(k₃, kdu, ku, p, t + dt)
@.. broadcast=false ku = uprev + dt * duprev + eightdtsq * k₃
@.. broadcast=false kdu = duprev + dt * k₃

#perform final calculations to determine new y and y'.
@.. broadcast=false u = uprev + sixthdtsq* (1*k₁ + 2*k₂ + 0*k₃) + dt * duprev
Expand All @@ -1912,4 +1918,15 @@ end

integrator.stats.nf += 3
integrator.stats.nf2 += 1
#=
if integrator.opts.adaptive
uhat = dtsq * (1/6 * k₁ + 1/3 * k₂ + 0 * k₃)
duhat = dt * (1/6 * k₁ + 1/3 * k₂ + 0 * k₃)
utilde = ArrayPartition((duhat, uhat))
atmp = calculate_residuals(utilde, integrator.uprev, integrator.u,
integrator.opts.abstol, integrator.opts.reltol,
integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end
=#
end
38 changes: 33 additions & 5 deletions test/algconvergence/partitioned_methods_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ sim = test_convergence(dts, prob, KahanLi8(), dense_errors = true)

sol = solve(prob, Nystrom4(), dt = 1 / 1000)

sol = solve(prob, RKN4(), dt = 1/1000)

# Nyström method
dts = 1 .// 2 .^ (9:-1:6)
sim = test_convergence(dts, prob, RKN4(), dense_errors = true)
Expand Down Expand Up @@ -179,6 +177,8 @@ sim = test_convergence(dts, prob_big, ERKN7(), dense_errors = true)
@test sim.𝒪est[:L2]4 rtol=1e-1

# Adaptive methods regression test
sol = solve(prob, RKN4(), reltol = 1e-8)
@test length(sol.u) < 16
sol = solve(prob, FineRKN4())
@test length(sol.u) < 16
sol = solve(prob, FineRKN5())
Expand Down Expand Up @@ -300,6 +300,9 @@ sim = test_convergence(dts, prob, KahanLi8(), dense_errors = true)

# Nyström method
dts = 1 .// 2 .^ (9:-1:6)
sim = test_convergence(dts, prob, RKN4(), dense_errors = true)
@test sim.𝒪est[:l2]4 rtol=1e-1
@test sim.𝒪est[:L2]4 rtol=1e-1
sim = test_convergence(dts, prob, Nystrom4(), dense_errors = true)
@test sim.𝒪est[:l2]4 rtol=1e-1
@test sim.𝒪est[:L2]4 rtol=1e-1
Expand Down Expand Up @@ -361,6 +364,8 @@ sim = test_convergence(dts, prob_big, ERKN7(), dense_errors = true)
@test sim.𝒪est[:L2]4 rtol=1e-1

# Adaptive methods regression test
sol = solve(prob, RKN4())
@test length(sol.u) < 16
sol = solve(prob, FineRKN4())
@test length(sol.u) < 16
sol = solve(prob, FineRKN5())
Expand All @@ -385,7 +390,7 @@ sol = solve(prob, ERKN7(), reltol = 1e-8)
@test length(sol.u) < 38

# Testing generalized Runge-Kutte-Nyström methods on velocity dependend ODEs with the damped oscillator
println("In Place")
println("Out of Place")

# Damped oscillator
prob = ODEProblem(
Expand Down Expand Up @@ -415,14 +420,19 @@ sim = test_convergence(dts, prob, FineRKN4(), dense_errors = true)
sim = test_convergence(dts, prob, FineRKN5(), dense_errors = true)
@test sim.𝒪est[:l2]5 rtol=1e-1
@test sim.𝒪est[:L2]4 rtol=1e-1
sim = test_convergence(dts, prob, RKN4(), dense_errors = true)
@test sim.𝒪est[:l2]4 rtol=1e-1
@test sim.𝒪est[:L2]4 rtol=1e-1

# Adaptive methods regression test

sol = solve(prob, FineRKN4())
@test length(sol.u) < 28
sol = solve(prob, FineRKN5())
@test length(sol.u) < 20

println("Out of Place")
sol = solve(prob, RKN4())
@test length(sol.u) < 30
println("In Place")
# Damped oscillator
prob = ODEProblem(
DynamicalODEFunction{true}((d_du, du, u, p, t) -> @.(d_du=-u - 0.5 * du),
Expand Down Expand Up @@ -451,6 +461,10 @@ sim = test_convergence(dts, prob, FineRKN4(), dense_errors = true)
sim = test_convergence(dts, prob, FineRKN5(), dense_errors = true)
@test sim.𝒪est[:l2]5 rtol=1e-1
@test sim.𝒪est[:L2]4 rtol=1e-1
sim = test_convergence(dts, prob, RKN4(), dense_errors = true)
@test sim.𝒪est[:l2]4 rtol=1e-1
@test sim.𝒪est[:L2]4 rtol=1e-1


# Adaptive methods regression test
sol = solve(prob, FineRKN4())
Expand Down Expand Up @@ -489,6 +503,20 @@ end
@test abs(sol_i.destats.nf - 4 * sol_i.destats.naccept) < 4
end

@testset "RKN4" begin
alg = RKN4()
dt = 0.5
# fixed time step
sol_i = solve(ode_i, alg, dt = dt)
sol_o = solve(ode_o, alg, dt = dt)
@test sol_i.t sol_o.t
@test sol_i.u sol_o.u
@test sol_i.destats.nf == sol_o.destats.nf
@test sol_i.destats.nf2 == sol_o.destats.nf2
@test sol_i.destats.naccept == sol_o.destats.naccept
@test 19 <= sol_i.destats.naccept <= 21
@test abs(sol_i.destats.nf - 4 * sol_i.destats.naccept) < 4
end
@testset "FineRKN4" begin
alg = FineRKN4()
dt = 0.5
Expand Down

0 comments on commit 4f55e58

Please sign in to comment.