Skip to content

Commit

Permalink
Make sure dispatch with dest's style work.
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Feb 21, 2022
1 parent b53148e commit bee0d4f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export AbstractGPUArrayStyle

using Base.Broadcast

import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate

const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
Base.RefValue{<:AbstractGPUArray{T}}}
Expand Down Expand Up @@ -47,7 +47,7 @@ end
copyto!(similar(bc, ElType), bc)
end

@inline function materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
end

Expand Down
20 changes: 20 additions & 0 deletions test/testsuite/broadcasting.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testsuite "broadcasting" (AT, eltypes)->begin
broadcasting(AT, eltypes)
vec3(AT, eltypes)
unknown_wrapper(AT, eltypes)

@testset "type instabilities" begin
f(x) = x ? 1.0 : 0
Expand Down Expand Up @@ -205,3 +206,22 @@ function vec3(AT, eltypes)
@test all(map((a,b)-> all((1,2,3) .≈ (1,2,3)), Array(res2), res2c))
end
end

struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
data::P
end
Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...]
Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...)
Base.size(A::WrapArray) = size(A.data)
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
function unknown_wrapper(AT, eltypes)
@views for ET in eltypes
A = AT(randn(ET, 10, 10))
WA = WrapArray(A)
@test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A)
@test Array(WA .+ A[:,1]) == Array(A .+ A[:,1])
@test Array(WA .+ A[1,:]) == Array(A .+ A[1,:])
WA .= ET(1) # test for dispatch with dest's BroadcastStyle.
@test all(isequal(ET(1)), Array(A))
end
end

0 comments on commit bee0d4f

Please sign in to comment.