Skip to content

Commit

Permalink
Remove special-casing of Ref in broadcast. (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Jan 4, 2024
1 parent 4278412 commit e033242
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 27 deletions.
15 changes: 3 additions & 12 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,6 @@ using Base.Broadcast

import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate

const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
Base.RefValue{<:AbstractGPUArray{T}}}

# Ref is special: it's not a real wrapper, so not part of Adapt,
# but it is commonly used to bypass broadcasting of an argument
# so we need to preserve its dimensionless properties.
BroadcastStyle(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} =
typeof(BroadcastStyle(AT))(Val(0))
backend(::Type{Base.RefValue{AT}}) where {AT<:AbstractGPUArray} = backend(AT)
# but make sure we don't dispatch to the optimized copy method that directly indexes
function Broadcast.copy(bc::Broadcasted{<:AbstractGPUArrayStyle{0}})
ElType = Broadcast.combine_eltypes(bc.f, bc.args)
Expand Down Expand Up @@ -41,7 +32,7 @@ end
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
end

@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict
@inline Base.copyto!(dest::AnyGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict

@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc)

Expand Down Expand Up @@ -77,7 +68,7 @@ end
allequal(x) = true
allequal(x, y, z...) = x == y && allequal(y, z...)

function Base.map(f, x::BroadcastGPUArray, xs::AbstractArray...)
function Base.map(f, x::AnyGPUArray, xs::AbstractArray...)
# if argument sizes match, their shape needs to be preserved
xs = (x, xs...)
if allequal(size.(xs)...)
Expand All @@ -96,7 +87,7 @@ function Base.map(f, x::BroadcastGPUArray, xs::AbstractArray...)
return map!(f, dest, xs...)
end

function Base.map!(f, dest::BroadcastGPUArray, xs::AbstractArray...)
function Base.map!(f, dest::AnyGPUArray, xs::AbstractArray...)
# custom broadcast, ignoring the container size mismatches
# (avoids the reshape + view that our mapreduce impl has to do)
indices = LinearIndices.((dest, xs...))
Expand Down
15 changes: 0 additions & 15 deletions test/testsuite/broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,6 @@ function broadcasting(AT, eltypes)
map(+, x, y)
end
end

@testset "Ref" begin
# as first arg, 0d broadcast
@test compare(x->getindex.(Ref(x), 1), AT, ET[0])

void_setindex!(args...) = (setindex!(args...); return)
@test compare(x->(void_setindex!.(Ref(x), ET(1)); x), AT, ET[0])

# regular broadcast
a = AT(rand(ET, 10))
b = AT(rand(ET, 10))
cpy(i,a,b) = (a[i] = b[i]; return)
cpy.(1:10, Ref(a), Ref(b))
@test Array(a) == Array(b)
end
end

@testset "stackoverflow in copy(::Broadcast)" begin
Expand Down

0 comments on commit e033242

Please sign in to comment.