Skip to content

Commit

Permalink
Merge pull request #2201 from SciML/dprkn6_interp
Browse files Browse the repository at this point in the history
Fix and test DPRKN interpolation with idxs
  • Loading branch information
ChrisRackauckas authored May 19, 2024
2 parents 1f0e158 + 2e7cb8c commit dd9f8cf
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 25 deletions.
65 changes: 45 additions & 20 deletions src/dense/interpolants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2947,6 +2947,28 @@ end
b4Θ * k4[idxs] + b5Θ * k5[idxs] + b6Θ * k6[idxs])))
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{DPRKN6ConstantCache, DPRKN6Cache}, idxs::Number,
T::Type{Val{0}}, differential_vars::Nothing)
@dprkn6pre0
halfsize = length(y₀) ÷ 2
if idxs <= halfsize
duprev[idxs] +
dt * Θ *
(bp1Θ * k1[idxs] + bp3Θ * k3[idxs] +
bp4Θ * k4[idxs] + bp5Θ * k5[idxs] + bp6Θ * k6[idxs])
else
idxs = idxs - halfsize
uprev[idxs] +
dt * Θ *
(duprev[idxs] +
dt * Θ *
(b1Θ * k1[idxs] +
b3Θ * k3[idxs] +
b4Θ * k4[idxs] + b5Θ * k5[idxs] + b6Θ * k6[idxs]))
end
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{DPRKN6ConstantCache, DPRKN6Cache},
idxs::Nothing, T::Type{Val{0}}, differential_vars::Nothing)
Expand Down Expand Up @@ -2976,25 +2998,28 @@ end
cache::Union{DPRKN6ConstantCache, DPRKN6Cache}, idxs,
T::Type{Val{0}}, differential_vars::Nothing)
@dprkn6pre0
@inbounds @.. broadcast=false out.x[2]=uprev[idxs] +
dt * Θ *
(duprev[idxs] +
dt * Θ *
(b1Θ * k1[idxs] +
b3Θ * k3[idxs] +
b4Θ * k4[idxs] + b5Θ * k5[idxs] +
b6Θ * k6[idxs]))
@inbounds @.. broadcast=false out.x[1]=duprev[idxs] +
dt * Θ *
(bp1Θ * k1[idxs] + bp3Θ * k3[idxs] +
bp4Θ * k4[idxs] + bp5Θ * k5[idxs] +
bp6Θ * k6[idxs])
#for (j,i) in enumerate(idxs)
# out.x[2][j] = uprev[i] + dt*Θ*(duprev[i] + dt*Θ*(b1Θ*k1[i] +
# b3Θ*k3[i] +
# b4Θ*k4[i] + b5Θ*k5[i] + b6Θ*k6[i]))
# out.x[1][j] = duprev[i] + dt*Θ*(bp1Θ*k1[i] + bp3Θ*k3[i] +
# bp4Θ*k4[i] + bp5Θ*k5[i] + bp6Θ*k6[i])
#end
halfsize = length(y₀) ÷ 2
isfirsthalf = idxs .<= halfsize
secondhalf = idxs .> halfsize
firstidxs = idxs[isfirsthalf]
secondidxs_shifted = idxs[secondhalf]
secondidxs = secondidxs_shifted .- halfsize

@views @.. broadcast=false out[secondhalf]=uprev[secondidxs] +
dt * Θ *
(duprev[secondidxs] +
dt * Θ *
(b1Θ * k1[secondidxs] +
b3Θ * k3[secondidxs] +
b4Θ * k4[secondidxs] +
b5Θ * k5[secondidxs] +
b6Θ * k6[secondidxs]))
@views @.. broadcast=false out[isfirsthalf]=duprev[firstidxs] +
dt * Θ *
(bp1Θ * k1[firstidxs] +
bp3Θ * k3[firstidxs] +
bp4Θ * k4[firstidxs] +
bp5Θ * k5[firstidxs] +
bp6Θ * k6[firstidxs])
out
end
3 changes: 2 additions & 1 deletion src/nlsolve/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ function build_nlsolver(
end
prob = NonlinearProblem(NonlinearFunction(nlf), copy(ztmp), nlp_params)
cache = init(prob, nlalg.alg)
nlcache = NonlinearSolveCache(nothing, tstep, nothing, nothing, invγdt, prob, cache)
nlcache = NonlinearSolveCache(
nothing, tstep, nothing, nothing, invγdt, prob, cache)
else
nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf,
invγdt, tType(nlalg.new_W_dt_cutoff), t)
Expand Down
35 changes: 31 additions & 4 deletions test/interface/interpolation_output_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ res1 = solve(prob, Vern8(), dt = 1 / 10, saveat = 1 / 10)
res3 = solve(prob, CalvoSanz4(), dt = 1 / 10, saveat = 1 / 10)

sol = solve(prob, CalvoSanz4(), dt = 1 / 10)
@test sol(0.32) isa OrdinaryDiffEq.ArrayPartition
@test sol(0.32, Val{1}) isa OrdinaryDiffEq.ArrayPartition
@test sol(0.32, Val{2}) isa OrdinaryDiffEq.ArrayPartition
@test sol(0.32, Val{3}) isa OrdinaryDiffEq.ArrayPartition
@test sol(0.32) isa RecursiveArrayTools.ArrayPartition
@test sol(0.32, Val{1}) isa RecursiveArrayTools.ArrayPartition
@test sol(0.32, Val{2}) isa RecursiveArrayTools.ArrayPartition
@test sol(0.32, Val{3}) isa RecursiveArrayTools.ArrayPartition

function f(du, u, p, t)
du .= u
Expand All @@ -40,3 +40,30 @@ sol(0:0.1:100; idxs = [1, 2])
@test sol(0:0.1:100) isa DiffEqArray
@test length(sol(0:0.1:100)) == length(0:0.1:100)
@test length(sol(0:0.1:100).u[1]) == 3

## Test DPRKN Interpolation

#Parameters
ω = 1

#Initial Conditions
x₀ = [0.0]
dx₀ =/ 2]
tspan = (0.0, 2π)

ϕ = atan((dx₀[1] / ω) / x₀[1])
A = (x₀[1]^2 + dx₀[1]^2)

function harmonicoscillator(ddu, du, u, ω, t)
ddu .= -ω^2 * u
end

prob = SecondOrderODEProblem(harmonicoscillator, dx₀, x₀, tspan, ω)
sol = solve(prob, DPRKN6())
@test sol(0.5) isa RecursiveArrayTools.ArrayPartition
@test sol(0.5; idxs = 1) isa Number
@test sol(0.5; idxs = [1]) isa Vector
@test sol(0.5; idxs = [1, 2]) isa Vector
@test Vector(sol(0.5)) == sol(0.5; idxs = [1, 2])
@test reverse(Vector(sol(0.5))) == sol(0.5; idxs = [2, 1])
@test Vector(sol(0.5)) == [sol(0.5; idxs = 1); sol(0.5; idxs = 2)]

0 comments on commit dd9f8cf

Please sign in to comment.