GPU broadcast regression with ArrayFuse on recent CUDA #1626
Here's a version without OrdinaryDiffEq.jl:
"""
ArrayFuse{AT, T, P} <: AbstractArray{T, 1}
GPU-friendly type wrapping two arrays, `visible` and `hidden`: when we `setindex!` a value `v` at index `i`,
we get
visible[i] = p[1] * visible[i] + p[2] * v
hidden[i] = hidden[i] + p[3] * visible[i]
where `p` is a parameter tuple of length 3.
"""
struct ArrayFuse{AT,T,P} <: AbstractArray{T,1}
visible::AT
hidden::AT
p::P
end
ArrayFuse(visible::AT, hidden::AT, p) where {AT} = ArrayFuse{AT,eltype(visible),typeof(p)}(visible, hidden, p)
# fused broadcast assignment: `af .= src` updates both wrapped arrays in place
@inline function Base.copyto!(af::ArrayFuse{AT,T,P}, src::Base.Broadcast.Broadcasted) where {AT,T,P}
@. af.visible = af.p[1] * af.visible + af.p[2] * src
@. af.hidden = af.hidden + af.p[3] * af.visible
end
@inline function Base.copyto!(af::ArrayFuse{AT,T,P}, src::AbstractArray) where {AT,T,P}
@. af.visible = af.p[1] * af.visible + af.p[2] * src
@. af.hidden = af.hidden + af.p[3] * af.visible
end
# disambiguate 0-dimensional (scalar) broadcasts such as `du .= 2.0`
@inline function Base.copyto!(af::ArrayFuse{AT,T,P}, src::Base.Broadcast.Broadcasted{F1,Axes,F,Args}) where {AT,T,P,F1<:Base.Broadcast.AbstractArrayStyle{0},Axes,F,Args<:Tuple}
@. af.visible = af.p[1] * af.visible + af.p[2] * src
@. af.hidden = af.hidden + af.p[3] * af.visible
end
# scalar indexing paths: not recommended on the GPU, but good to have
@inline function Base.getindex(af::ArrayFuse, index)
return af.visible[index]
end
@inline function Base.setindex!(af::ArrayFuse, value, index)
af.visible[index] = af.p[1] * af.visible[index] + af.p[2] * value
af.hidden[index] = muladd(af.p[3], af.visible[index], af.hidden[index])
end
@inline Base.size(af::ArrayFuse) = size(af.visible) # size must return a tuple, matching axes below
@inline Base.axes(af::ArrayFuse) = axes(af.visible)
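For reference, here's what the fused update looks like on plain CPU arrays (a minimal sketch; the values are just illustrative):
visible = [1.0, 2.0]
hidden = [0.0, 0.0]
af = ArrayFuse(visible, hidden, (0.5, 2.0, 1.0))
af .= [3.0, 4.0]
# visible == [0.5*1.0 + 2.0*3.0, 0.5*2.0 + 2.0*4.0] == [6.5, 9.0]
# hidden  == [0.0 + 1.0*6.5, 0.0 + 1.0*9.0]         == [6.5, 9.0]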
using CUDA
CUDA.allowscalar(false)
N = 256
# Define the initial condition as normal arrays
u0 = zeros(N, N, 3)
u0 .= 1.0
gu0 = CuArray(Float32.(u0))
tmp, u, a, b = [copy(gu0) for i in 1:4]
dt = 0.01
du = ArrayFuse(tmp, u, (a, dt, b))
du .= u
ERROR: This object is not a GPU array
Stacktrace:
[1] error(s::String)
@ Base .\error.jl:33
[2] backend(#unused#::Type)
@ GPUArrays C:\Users\accou\.julia\packages\GPUArrays\VNhDf\src\device\execution.jl:15
[3] backend(x::ArrayFuse{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Float32, Tuple{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Float64, CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}})
@ GPUArrays C:\Users\accou\.julia\packages\GPUArrays\VNhDf\src\device\execution.jl:16
[4] _copyto!
@ C:\Users\accou\.julia\packages\GPUArrays\VNhDf\src\host\broadcast.jl:73 [inlined]
[5] materialize!
@ C:\Users\accou\.julia\packages\GPUArrays\VNhDf\src\host\broadcast.jl:51 [inlined]
[6] materialize!(dest::ArrayFuse{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Float32, Tuple{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}, Float64, CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}}, bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{3}, Nothing, typeof(identity), Tuple{CuArray{Float32, 3, CUDA.Mem.DeviceBuffer}}})
@ Base.Broadcast .\broadcast.jl:868
[7] top-level scope
@ c:\Users\accou\OneDrive\Computer\Desktop\test.jl:329
@maleadt can you help us figure out what changed in CUDA broadcast so we can override this? I think the solution might be to override …
Maybe JuliaGPU/GPUArrays.jl#393?
MWE at JuliaGPU/GPUArrays.jl#404
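If the cause is the GPUArrays.jl#393 change (materialize! now dispatches on the broadcast style, here CUDA.CuArrayStyle, rather than on the destination, so the wrapper falls through to GPUArrays' generic _copyto! and its backend(::Type) check), one possible workaround might be to intercept materialize! for ArrayFuse destinations under the GPU style. A sketch, untested, assuming GPUArrays.AbstractGPUArrayStyle is the style supertype seen in the stacktrace:
using GPUArrays
@inline function Base.Broadcast.materialize!(::GPUArrays.AbstractGPUArrayStyle, af::ArrayFuse,
                                             bc::Base.Broadcast.Broadcasted)
    # route back to the ArrayFuse copyto! methods above instead of GPUArrays'
    # generic _copyto!, which calls backend(typeof(af)) and errors
    return copyto!(af, bc)
end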
ChrisRackauckas added a commit that referenced this issue on Mar 27, 2022: "Label tests broken due to #1626"
ChrisRackauckas added a commit that referenced this issue on May 15, 2022