From 31002ff379a9cbbdb60f1298ae4dde1a727ce7c0 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 16 May 2024 20:04:01 +0100 Subject: [PATCH 1/6] avoid scalar indexing bug if diff_vars is filtrues --- src/dense/generic_dense.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 1a45aa5855..a37dcdf75e 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -748,6 +748,14 @@ end Θ * dt * k[2]) end +@muladd function hermite_interpolant!( + out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{0}}, differential_vars::FillArrays.Trues{1, Tuple{Base.OneTo{Int}}}) # Default interpolant is Hermite +@inbounds @.. broadcast=false out=(1 - Θ) * y₀ + Θ * y₁ + + Θ * (Θ - 1) * + ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + + Θ * dt * k[2]) +end + @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite @inbounds @simd ivdep for i in eachindex(out) From f46750aea4f77dedff69b7638c73ee095ea82cba Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Fri, 17 May 2024 16:34:41 +0100 Subject: [PATCH 2/6] test --- test/gpu/Project.toml | 5 +++++ test/gpu/hermite_test.jl | 15 +++++++++++++++ test/runtests.jl | 1 + 3 files changed, 21 insertions(+) create mode 100644 test/gpu/hermite_test.jl diff --git a/test/gpu/Project.toml b/test/gpu/Project.toml index 39a8540125..b42a597fa8 100644 --- a/test/gpu/Project.toml +++ b/test/gpu/Project.toml @@ -1,6 +1,11 @@ [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" [compat] CUDA = "5" diff --git a/test/gpu/hermite_test.jl b/test/gpu/hermite_test.jl new file mode 100644 index 0000000000..13c3f2cac7 --- /dev/null +++ b/test/gpu/hermite_test.jl @@ -0,0 +1,15 @@ +using ComponentArrays, CUDA, Adapt, RecursiveArrayTools, FastBroadcast, FillArrays, OrdinaryDiffEq, Test + +a = ComponentArray((a=rand(Float32, 5,5), b=rand(Float32, 5, 5))) +a = adapt(CuArray, a) +pa = ArrayPartition(a) +pb = deepcopy(pa) +pc = deepcopy(pa) +pd = deepcopy(pa) +pe = deepcopy(pa) +k = [pd, pe] +t = FillArrays.Trues(length(pa)) + +OrdinaryDiffEq.hermite_interpolant!(pa, 0.1, 0.2, pb, pc, k, nothing, Val{0}, t) # if this doesnt error we're good + +@test pa.a != pb.a \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 6cdf509647..6ee7b513e6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -192,5 +192,6 @@ end @time @safetestset "Autoswitch GPU" include("gpu/autoswitch.jl") @time @safetestset "Linear LSRK GPU" include("gpu/linear_lsrk.jl") @time @safetestset "Reaction-Diffusion Stiff Solver GPU" include("gpu/reaction_diffusion_stiff.jl") + @time @safetestset "Scalar indexing bug bypass" include("gpu/hermite_test.jl") end end # @time From b3fc88c3cf35b10fb3666920eb222a4724e9b0ca Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 19 May 2024 12:54:40 -0400 Subject: [PATCH 3/6] Complete the Hermites --- src/dense/generic_dense.jl | 358 ++++++++++++++++++++++++------------- test/gpu/hermite_test.jl | 7 +- 2 files changed, 237 insertions(+), 128 deletions(-) diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index a37dcdf75e..99f4f7c9ff 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -700,7 +700,7 @@ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Proble Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{0}}, differential_vars) #@.. broadcast=false (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2]) if all(differential_vars) @inbounds (1 - Θ) * y₀ + Θ * y₁ + @@ -714,15 +714,21 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{0}}, differential_vars) #@.. broadcast=false (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2]) - @inbounds @.. broadcast=false (1 - Θ)*y₀+Θ*y₁+ - differential_vars*Θ*(Θ-1)* - ((1 - 2Θ)*(y₁ - y₀)+(Θ-1)*dt*k[1]+Θ*dt*k[2]) + if all(differential_vars) + @inbounds @.. broadcast=false (1 - Θ)*y₀+Θ*y₁+ + Θ*(Θ-1)* + ((1 - 2Θ)*(y₁ - y₀)+(Θ-1)*dt*k[1]+Θ*dt*k[2]) + else + @inbounds @.. broadcast=false (1 - Θ)*y₀+Θ*y₁+ + differential_vars*Θ*(Θ-1)* + ((1 - 2Θ)*(y₁ - y₀)+(Θ-1)*dt*k[1]+Θ*dt*k[2]) + end end @muladd function hermite_interpolant(Θ, dt, y₀::Array, y₁, k, ::Type{Val{true}}, - idxs::Nothing, T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite + idxs::Nothing, T::Type{Val{0}}, differential_vars) out = similar(y₀) @inbounds @simd ivdep for i in eachindex(y₀) out[i] = (1 - Θ) * y₀[i] + Θ * y₁[i] + @@ -732,32 +738,39 @@ end end @muladd function hermite_interpolant( - Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite + Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{0}}, differential_vars) # return @.. broadcast=false (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*(Θ-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+(Θ-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs]) - return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + - differential_vars .* (Θ * (Θ - 1) * - ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] + - Θ * dt * k[2][idxs])) -end - -@muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite - @inbounds @.. broadcast=false out=(1 - Θ) * y₀ + Θ * y₁ + - differential_vars * Θ * (Θ - 1) * - ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + - Θ * dt * k[2]) + if all(differential_vars) + return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + + (Θ * (Θ - 1) * + ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] + + Θ * dt * k[2][idxs])) + else + return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + + differential_vars[idxs] .* (Θ * (Θ - 1) * + ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] + + Θ * dt * k[2][idxs])) + end end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{0}}, differential_vars::FillArrays.Trues{1, Tuple{Base.OneTo{Int}}}) # Default interpolant is Hermite -@inbounds @.. broadcast=false out=(1 - Θ) * y₀ + Θ * y₁ + - Θ * (Θ - 1) * - ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + - Θ * dt * k[2]) + out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{0}}, differential_vars) + if all(differential_vars) + @inbounds @.. broadcast=false out=(1 - Θ) * y₀ + Θ * y₁ + + Θ * (Θ - 1) * + ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + + Θ * dt * k[2]) + else + @inbounds @.. broadcast=false out=(1 - Θ) * y₀ + Θ * y₁ + + differential_vars * Θ * (Θ - 1) * + ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + + Θ * dt * k[2]) + end + out end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{0}}, differential_vars) @inbounds @simd ivdep for i in eachindex(out) out[i] = (1 - Θ) * y₀[i] + Θ * y₁[i] + differential_vars[i] * Θ * (Θ - 1) * @@ -767,15 +780,23 @@ end end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite - @views @.. broadcast=false out=(1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + - differential_vars * Θ * (Θ - 1) * - ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + - (Θ - 1) * dt * k[1][idxs] + Θ * dt * k[2][idxs]) + out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}, differential_vars) + if all(differential_vars) + @views @.. broadcast=false out=(1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + + Θ * (Θ - 1) * + ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + + (Θ - 1) * dt * k[1][idxs] + Θ * dt * k[2][idxs]) + else + @views @.. broadcast=false out=(1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + + differential_vars * Θ * (Θ - 1) * + ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + + (Θ - 1) * dt * k[1][idxs] + Θ * dt * k[2][idxs]) + end + out end @muladd function hermite_interpolant!( - out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite + out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{0}}, differential_vars) @inbounds for (j, i) in enumerate(idxs) out[j] = (1 - Θ) * y₀[i] + Θ * y₁[i] + differential_vars[j] * Θ * (Θ - 1) * @@ -788,7 +809,7 @@ end Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{1}}, differential_vars) #@.. broadcast=false k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt if all(differential_vars) @inbounds ( @@ -805,41 +826,69 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite - @inbounds @.. broadcast=false !differential_vars * - ((y₁ - y₀) / - dt)+differential_vars * ( - k[1] + - Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + - Θ * - (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + - 6 * y₁) / dt) + T::Type{Val{1}}, differential_vars) + if all(differential_vars) + @inbounds @.. broadcast=false !differential_vars * + ((y₁ - y₀) / + dt)+( + k[1] + + Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + + 6 * y₁) / dt) + else + @inbounds @.. broadcast=false !differential_vars * + ((y₁ - y₀) / + dt)+differential_vars * ( + k[1] + + Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + + 6 * y₁) / dt) + end end @muladd function hermite_interpolant( - Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite - # return @.. broadcast=false k[1][idxs] + Θ*(-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(3*dt*k[1][idxs] + 3*dt*k[2][idxs] + 6*y₀[idxs] - 6*y₁[idxs]) + 6*y₁[idxs])/dt - return (.!differential_vars) .* ((y₁[idxs] - y₀[idxs]) / dt) + - differential_vars .* ( - k[1][idxs] + - Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + - Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + 6 * y₀[idxs] - 6 * y₁[idxs]) + - 6 * y₁[idxs]) / dt) + Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{1}}, differential_vars) + if all(differential_vars) + ( + k[1][idxs] + + Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + + Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + 6 * y₀[idxs] - 6 * y₁[idxs]) + + 6 * y₁[idxs]) / dt) + else + (.!differential_vars[idxs]) .* ((y₁[idxs] - y₀[idxs]) / dt) + + differential_vars[idxs] .* ( + k[1][idxs] + + Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + + Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + 6 * y₀[idxs] - 6 * y₁[idxs]) + + 6 * y₁[idxs]) / dt) + end end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite - @inbounds @.. broadcast=false out=!differential_vars * ((y₁ - y₀) / dt) + - differential_vars * ( - k[1] + - Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + - Θ * - (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + - 6 * y₁) / dt) + out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{1}}, differential_vars) + if all(differential_vars) + @inbounds @.. broadcast=false out=( + k[1] + + Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + + 6 * y₁) / dt) + else + @inbounds @.. broadcast=false out=!differential_vars * ((y₁ - y₀) / dt) + + differential_vars * ( + k[1] + + Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + + 6 * y₁) / dt) + end + out end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{1}}, differential_vars) @inbounds @simd ivdep for i in eachindex(out) out[i] = !differential_vars[i] * ((y₁[i] - y₀[i]) / dt) + differential_vars[i] * ( @@ -852,18 +901,27 @@ end end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite - @views @.. broadcast=false out=!differential_vars * ((y₁ - y₀) / dt) + - differential_vars * ( - k[1][idxs] + - Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - - 6 * y₀[idxs] + - Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + - 6 * y₀[idxs] - 6 * y₁[idxs]) + 6 * y₁[idxs]) / dt) + out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}, differential_vars) + if all(differential_vars) + @views @.. broadcast=false out=( + k[1][idxs] + + Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - + 6 * y₀[idxs] + + Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + + 6 * y₀[idxs] - 6 * y₁[idxs]) + 6 * y₁[idxs]) / dt) + else + @views @.. broadcast=false out=!differential_vars * ((y₁ - y₀) / dt) + + differential_vars * ( + k[1][idxs] + + Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - + 6 * y₀[idxs] + + Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + + 6 * y₀[idxs] - 6 * y₁[idxs]) + 6 * y₁[idxs]) / dt) + end end @muladd function hermite_interpolant!( - out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite + out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{1}}, differential_vars) @inbounds for (j, i) in enumerate(idxs) out[j] = !differential_vars[j] * ((y₁[i] - y₀[i]) / dt) + differential_vars[j] * ( @@ -879,8 +937,7 @@ end Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite - #@.. broadcast=false (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt) + T::Type{Val{2}}, differential_vars) if all(differential_vars) @inbounds (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / @@ -893,36 +950,57 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite - #@.. broadcast=false (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt) - @inbounds @.. broadcast=false differential_vars * - (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + - Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + - 6 * y₁)/(dt * dt) + T::Type{Val{2}}, differential_vars) + if all(differential_vars) + @inbounds @.. broadcast=false (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + + 6 * y₁)/(dt * dt) + else + @inbounds @.. broadcast=false differential_vars * + (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + + 6 * y₁)/(dt * dt) + end end @muladd function hermite_interpolant( - Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite - #out = similar(y₀,axes(idxs)) - #@views @.. broadcast=false out = (-4*dt*k[1][idxs] - 2*dt*k[2][idxs] - 6*y₀[idxs] + Θ*(6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs]) + 6*y₁[idxs])/(dt*dt) - @views out = differential_vars .* - (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + - Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - - 12 * y₁[idxs]) + 6 * y₁[idxs]) / (dt * dt) + Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{2}}, differential_vars) + if all(differential_vars) + @views out = (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - + 12 * y₁[idxs]) + 6 * y₁[idxs]) / (dt * dt) + else + @views out = differential_vars[idxs] .* + (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - + 12 * y₁[idxs]) + 6 * y₁[idxs]) / (dt * dt) + end out end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite - @inbounds @.. broadcast=false out=differential_vars * - (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + - Θ * - (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + - 6 * y₁) / (dt * dt) + out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{2}}, differential_vars) + if all(differential_vars) + @inbounds @.. broadcast=false out=(-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - + 12 * y₁) + + 6 * y₁) / (dt * dt) + else + @inbounds @.. broadcast=false out=differential_vars * + (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ + + Θ * + (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - + 12 * y₁) + + 6 * y₁) / (dt * dt) + end + out end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{2}}, differential_vars) @inbounds @simd ivdep for i in eachindex(out) out[i] = differential_vars[i] * (-4 * dt * k[1][i] - 2 * dt * k[2][i] - 6 * y₀[i] + @@ -933,17 +1011,26 @@ end end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite - @views @.. broadcast=false out=differential_vars * - (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - - 6 * y₀[idxs] + - Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + - 12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) / - (dt * dt) + out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}, differential_vars) + if all(differential_vars) + @views @.. broadcast=false out=(-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - + 6 * y₀[idxs] + + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + + 12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) / + (dt * dt) + else + @views @.. broadcast=false out=differential_vars[idxs] * + (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - + 6 * y₀[idxs] + + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + + 12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) / + (dt * dt) + end + out end @muladd function hermite_interpolant!( - out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite + out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}, differential_vars) @inbounds for (j, i) in enumerate(idxs) out[j] = differential_vars[j] * (-4 * dt * k[1][i] - 2 * dt * k[2][i] - 6 * y₀[i] + @@ -957,7 +1044,7 @@ end Herimte Interpolation, chosen if no other dispatch for ode_interpolant """ @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing, - T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{3}}, differential_vars) #@.. broadcast=false (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt) if all(differential_vars) @inbounds (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt) @@ -968,39 +1055,53 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant end @muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing, - T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite - #@.. broadcast=false (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt) - @inbounds @.. broadcast=false differential_vars * - (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - - 12 * y₁)/(dt * - dt * - dt) + T::Type{Val{3}}, differential_vars) + if all(differential_vars) + @inbounds @.. broadcast=false (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - + 12 * y₁)/(dt * + dt * + dt) + else + @inbounds @.. broadcast=false differential_vars * + (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - + 12 * y₁)/(dt * + dt * + dt) + end end @muladd function hermite_interpolant( - Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite - #out = similar(y₀,axes(idxs)) - #@views @.. broadcast=false out = (6*dt*k[1][idxs] + 6*dt*k[2][idxs] + 12*y₀[idxs] - 12*y₁[idxs])/(dt*dt*dt) - @views out = differential_vars .* - (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - - 12 * y₁[idxs]) / - (dt * dt * dt) + Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{3}}, differential_vars) + if all(differential_vars) + @views out = (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - + 12 * y₁[idxs]) / + (dt * dt * dt) + else + @views out = differential_vars[idxs] .* + (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - + 12 * y₁[idxs]) / + (dt * dt * dt) + end out end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite - @inbounds @.. broadcast=false out=differential_vars * - (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / - (dt * dt * dt) - #for i in eachindex(out) - # out[i] = (6*dt*k[1][i] + 6*dt*k[2][i] + 12*y₀[i] - 12*y₁[i])/(dt*dt*dt) - #end - #out + out, Θ, dt, y₀, y₁, k, idxs::Nothing, T::Type{Val{3}}, differential_vars) + if all(differential_vars) + @inbounds @.. broadcast=false out=(6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - + 12 * y₁) / + (dt * dt * dt) + else + @inbounds @.. broadcast=false out=differential_vars * + (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - + 12 * y₁) / + (dt * dt * dt) + end + out end @muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs::Nothing, - T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite + T::Type{Val{3}}, differential_vars) @inbounds @simd ivdep for i in eachindex(out) out[i] = differential_vars[i] * (6 * dt * k[1][i] + 6 * dt * k[2][i] + 12 * y₀[i] - 12 * y₁[i]) / @@ -1010,13 +1111,20 @@ end end @muladd function hermite_interpolant!( - out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite - @views @.. broadcast=false out=(6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + - 12 * y₀[idxs] - 12 * y₁[idxs]) / (dt * dt * dt) + out, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}, differential_vars) + if all(differential_vars) + @views @.. broadcast=false out=(6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + + 12 * y₀[idxs] - 12 * y₁[idxs]) / (dt * dt * dt) + else + @views @.. broadcast=false out=differential_vars[idxs] * + (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + + 12 * y₀[idxs] - 12 * y₁[idxs]) / (dt * dt * dt) + end + out end @muladd function hermite_interpolant!( - out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite + out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{3}}, differential_vars) @inbounds for (j, i) in enumerate(idxs) out[j] = differential_vars[j] * (6 * dt * k[1][i] + 6 * dt * k[2][i] + 12 * y₀[i] - 12 * y₁[i]) / diff --git a/test/gpu/hermite_test.jl b/test/gpu/hermite_test.jl index 13c3f2cac7..1073d8fd4a 100644 --- a/test/gpu/hermite_test.jl +++ b/test/gpu/hermite_test.jl @@ -1,6 +1,7 @@ -using ComponentArrays, CUDA, Adapt, RecursiveArrayTools, FastBroadcast, FillArrays, OrdinaryDiffEq, Test +using ComponentArrays, CUDA, Adapt, RecursiveArrayTools, FastBroadcast, FillArrays, + OrdinaryDiffEq, Test -a = ComponentArray((a=rand(Float32, 5,5), b=rand(Float32, 5, 5))) +a = ComponentArray((a = rand(Float32, 5, 5), b = rand(Float32, 5, 5))) a = adapt(CuArray, a) pa = ArrayPartition(a) pb = deepcopy(pa) @@ -12,4 +13,4 @@ t = FillArrays.Trues(length(pa)) OrdinaryDiffEq.hermite_interpolant!(pa, 0.1, 0.2, pb, pc, k, nothing, Val{0}, t) # if this doesnt error we're good -@test pa.a != pb.a \ No newline at end of file +@test pa.a != pb.a From 6761f5a50f7b4ca6df6e1f9b99d3e46886d00869 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 19 May 2024 13:12:19 -0400 Subject: [PATCH 4/6] fix some algebraic parts --- src/dense/generic_dense.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 99f4f7c9ff..7b24a3df06 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -1077,7 +1077,7 @@ end 12 * y₁[idxs]) / (dt * dt * dt) else - @views out = differential_vars[idxs] .* + @views out = differential_vars .* (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - 12 * y₁[idxs]) / (dt * dt * dt) From a3892ab165fbddc7208459dc316223d33741f258 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 19 May 2024 13:57:11 -0400 Subject: [PATCH 5/6] fix indexing --- src/dense/generic_dense.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index 7b24a3df06..52bd2126e3 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -747,7 +747,7 @@ end Θ * dt * k[2][idxs])) else return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] + - differential_vars[idxs] .* (Θ * (Θ - 1) * + differential_vars .* (Θ * (Θ - 1) * ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] + Θ * dt * k[2][idxs])) end @@ -857,8 +857,8 @@ end Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + 6 * y₀[idxs] - 6 * y₁[idxs]) + 6 * y₁[idxs]) / dt) else - (.!differential_vars[idxs]) .* ((y₁[idxs] - y₀[idxs]) / dt) + - differential_vars[idxs] .* ( + (.!differential_vars) .* ((y₁[idxs] - y₀[idxs]) / dt) + + differential_vars .* ( k[1][idxs] + Θ * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + Θ * (3 * dt * k[1][idxs] + 3 * dt * k[2][idxs] + 6 * y₀[idxs] - 6 * y₁[idxs]) + @@ -972,7 +972,7 @@ end Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) / (dt * dt) else - @views out = differential_vars[idxs] .* + @views out = differential_vars .* (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) / (dt * dt) @@ -1019,7 +1019,7 @@ end 12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) / (dt * dt) else - @views @.. broadcast=false out=differential_vars[idxs] * + @views @.. broadcast=false out=differential_vars * (-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] - 6 * y₀[idxs] + Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + @@ -1116,7 +1116,7 @@ end @views @.. broadcast=false out=(6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - 12 * y₁[idxs]) / (dt * dt * dt) else - @views @.. broadcast=false out=differential_vars[idxs] * + @views @.. broadcast=false out=differential_vars * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] + 12 * y₀[idxs] - 12 * y₁[idxs]) / (dt * dt * dt) end From 9fce5e856abc5bec8ccb1f9c093bd514bdd4cf5d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 19 May 2024 14:17:07 -0400 Subject: [PATCH 6/6] fix gpu test --- test/gpu/hermite_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/gpu/hermite_test.jl b/test/gpu/hermite_test.jl index 1073d8fd4a..db7d233884 100644 --- a/test/gpu/hermite_test.jl +++ b/test/gpu/hermite_test.jl @@ -13,4 +13,4 @@ t = FillArrays.Trues(length(pa)) OrdinaryDiffEq.hermite_interpolant!(pa, 0.1, 0.2, pb, pc, k, nothing, Val{0}, t) # if this doesnt error we're good -@test pa.a != pb.a +@test pa.x[1] != pb.x[1]