Skip to content

Commit

Permalink
Reconstruct Broadcasted in kernel to help Enzyme.jl (#539)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jun 25, 2024
1 parent 800c237 commit 8c5d550
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,39 +47,61 @@ end
@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest

# to help Enzyme.jl, we won't pass the broadcasted object directly
# but instead pass its arguments and reconstruct the object device-side
bc = Broadcast.preprocess(dest, bc)
bcstyle = @static if VERSION >= v"1.10-"
bc.style
else
typeof(BroadcastStyle(typeof(bc)))
end

broadcast_kernel = if ndims(dest) == 1 ||
(isa(IndexStyle(dest), IndexLinear) &&
isa(IndexStyle(bc), IndexLinear))
function (ctx, dest, bc, nelem)
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
bc′ = @static if VERSION >= v"1.10-"
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end

i = 1
while i <= nelem
I = @linearidx(dest, i)
@inbounds dest[I] = bc[I]
@inbounds dest[I] = bc′[I]
i += 1
end
return
end
else
function (ctx, dest, bc, nelem)
function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...)
bc′ = @static if VERSION >= v"1.10-"
Broadcasted(bcstyle, bcf, bcargs, bcaxes)
else
Broadcasted{bcstyle}(bcf, bcargs, bcaxes)
end

i = 0
while i < nelem
i += 1
I = @cartesianidx(dest, i)
@inbounds dest[I] = bc[I]
@inbounds dest[I] = bc′[I]
end
return
end
end

elements = length(dest)
elements_per_thread = typemax(Int)
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1;
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1,
bcstyle, bc.f, bc.axes, bc.args...;
elements, elements_per_thread)
config = launch_configuration(backend(dest), heuristic;
elements, elements_per_thread)
gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread;
gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int,
bcstyle, bc.f, bc.axes, bc.args...;
threads=config.threads, blocks=config.blocks)

if eltype(dest) <: BrokenBroadcast
Expand Down

0 comments on commit 8c5d550

Please sign in to comment.