From 6732699e235fa9e3a2d001de11b963c0a407f214 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 21 Feb 2022 01:38:19 +0800 Subject: [PATCH] Update broadcast.jl --- src/host/broadcast.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index 661e2dcd..7b75c9c2 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -47,7 +47,15 @@ end copyto!(similar(bc, ElType), bc) end -@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) +@inline function materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle} + return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) +end + +@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict + +@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc) + +@inline function _copyto!(dest::AbstractArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest bc′ = Broadcast.preprocess(dest, bc) @@ -72,12 +80,6 @@ end return dest end -# Base defines this method as a performance optimization, but we don't know how to do -# `fill!` in general for all `BroadcastGPUArray` so we just go straight to the fallback -@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) = - copyto!(dest, convert(Broadcasted{Nothing}, bc)) - - ## map allequal(x) = true