diff --git a/Project.toml b/Project.toml index f24332fb..85e19593 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUArrays" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.2.1" +version = "10.2.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index 79aa9759..a78e6c7c 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -47,47 +47,27 @@ end @inline function _copyto!(dest::AbstractArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest - - # to help Enzyme.jl, we won't pass the broadcasted object directly - # but instead pass its arguments and reconstruct the object device-side bc = Broadcast.preprocess(dest, bc) - bcstyle = @static if VERSION >= v"1.10-" - bc.style - else - typeof(BroadcastStyle(typeof(bc))) - end broadcast_kernel = if ndims(dest) == 1 || (isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc), IndexLinear)) - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) - bc′ = @static if VERSION >= v"1.10-" - Broadcasted(bcstyle, bcf, bcargs, bcaxes) - else - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) - end - + function (ctx, dest, bc, nelem) i = 1 while i <= nelem I = @linearidx(dest, i) - @inbounds dest[I] = bc′[I] + @inbounds dest[I] = bc[I] i += 1 end return end else - function (ctx, dest, nelem, bcstyle, bcf, bcaxes, bcargs...) - bc′ = @static if VERSION >= v"1.10-" - Broadcasted(bcstyle, bcf, bcargs, bcaxes) - else - Broadcasted{bcstyle}(bcf, bcargs, bcaxes) - end - + function (ctx, dest, bc, nelem) i = 0 while i < nelem i += 1 I = @cartesianidx(dest, i) - @inbounds dest[I] = bc′[I] + @inbounds dest[I] = bc[I] end return end @@ -95,13 +75,11 @@ end elements = length(dest) elements_per_thread = typemax(Int) - heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, 1, - bcstyle, bc.f, bc.axes, bc.args...; + heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc, 1; elements, elements_per_thread) config = launch_configuration(backend(dest), heuristic; elements, elements_per_thread) - gpu_call(broadcast_kernel, dest, config.elements_per_thread::Int, - bcstyle, bc.f, bc.axes, bc.args...; + gpu_call(broadcast_kernel, dest, bc, config.elements_per_thread; threads=config.threads, blocks=config.blocks) if eltype(dest) <: BrokenBroadcast