diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index a78e6c7c..59ce6207 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -52,17 +52,19 @@ end broadcast_kernel = if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear)) - function (ctx, dest, bc, nelem) + function (ctx, dest, bcstyle, bcf, bcaxes, nelem, bcargs...) + bc2 = Base.Broadcast.Broadcasted(bcstyle, bcf, bcargs, bcaxes) i = 1 while i <= nelem I = @linearidx(dest, i) - @inbounds dest[I] = bc[I] + @inbounds dest[I] = bc2[I] i += 1 end return end else - function (ctx, dest, bc, nelem) + function (ctx, dest, bcstyle, bcf, bcaxes, nelem, bcargs...) + bc2 = Base.Broadcast.Broadcasted(bcstyle, bcf, bcargs, bcaxes) i = 0 while i < nelem i += 1 @@ -75,11 +77,11 @@ end elements = length(dest) elements_per_thread = typemax(Int) - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1; + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc.style, bc.f, bc.axes, 1, bc.args...; elements, elements_per_thread) config = launch_configuration(backend(dest), heuristic; elements, elements_per_thread) - gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread; + gpu_call(broadcast_kernel, dest, bc.style, bc.f, bc.axes, config.elements_per_thread, bc.args...; threads=config.threads, blocks=config.blocks) if eltype(dest) <: BrokenBroadcast