Skip to content

Commit

Permalink
Revert "Reconstruct Broadcasted in kernel to help Enzyme.jl (JuliaGPU#539)"
Browse files Browse the repository at this point in the history

This reverts commit 8c5d550.
  • Loading branch information
maleadt committed Jun 28, 2024
1 parent cd1f59a commit 40fa8c0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GPUArrays"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "10.2.1"
version = "10.2.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
34 changes: 6 additions & 28 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,61 +47,39 @@ end
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

# to help Enzyme.jl, we won't pass the broadcasted object directly
# but instead pass its arguments and reconstruct the object device-side
bc = Broadcast.preprocess(dest, bc)
bcstyle = @static if VERSION >= v"1.10-"
bc.style
else
typeof(BroadcastStyle(typeof(bc)))
end

broadcast_kernel = if ndims(dest) == 1 ||
(isa(IndexStyle(dest), IndexLinear) &&
isa(IndexStyle(bc), IndexLinear))
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
bc′ = @static if VERSION >= v"1.10-"
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end

function (ctx, dest, bc, nelem)
i = 1
while i <= nelem
I = @linearidx(dest, i)
@inbounds dest[I] = bc[I]
@inbounds dest[I] = bc[I]
i += 1
end
return
end
else
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
bc′ = @static if VERSION >= v"1.10-"
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end

function (ctx, dest, bc, nelem)
i = 0
while i < nelem
i += 1
I = @cartesianidx(dest, i)
@inbounds dest[I] = bc[I]
@inbounds dest[I] = bc[I]
end
return
end
end

elements = length(dest)
elements_per_thread = typemax(Int)
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1,
bcstyle, bc.f, bc.axes, bc.args...;
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1;
elements, elements_per_thread)
config = launch_configuration(backend(dest), heuristic;
elements, elements_per_thread)
gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int,
bcstyle, bc.f, bc.axes, bc.args...;
gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
threads=config.threads, blocks=config.blocks)

if eltype(dest) <: BrokenBroadcast
Expand Down

0 comments on commit 40fa8c0

Please sign in to comment.