From b53148e6e0c4b791272eaa03242e67a6e66b67cb Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 21 Feb 2022 01:38:19 +0800 Subject: [PATCH 1/3] Use style to dispatch Broadcast --- src/host/broadcast.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index 661e2dcd..7b75c9c2 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -47,7 +47,15 @@ end copyto!(similar(bc, ElType), bc) end -@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) +@inline function materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle} + return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) +end + +@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict + +@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc) + +@inline function _copyto!(dest::AbstractArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest bc′ = Broadcast.preprocess(dest, bc) @@ -72,12 +80,6 @@ end return dest end -# Base defines this method as a performance optimization, but we don't know how to do -# `fill!` in general for all `BroadcastGPUArray` so we just go straight to the fallback -@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) = - copyto!(dest, convert(Broadcasted{Nothing}, bc)) - - ## map allequal(x) = true From bee0d4f85224de2c6bab0088c8ad54d76f2db8a9 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 21 Feb 2022 16:39:44 +0800 Subject: [PATCH 2/3] Make sure dispatch with `dest`'s style work. --- src/host/broadcast.jl | 4 ++-- test/testsuite/broadcasting.jl | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index 7b75c9c2..7a037985 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -4,7 +4,7 @@ export AbstractGPUArrayStyle using Base.Broadcast -import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle +import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate const BroadcastGPUArray{T} = Union{AnyGPUArray{T}, Base.RefValue{<:AbstractGPUArray{T}}} @@ -47,7 +47,7 @@ end copyto!(similar(bc, ElType), bc) end -@inline function materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle} +@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle} return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) end diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl index 939af21b..fbb02683 100644 --- a/test/testsuite/broadcasting.jl +++ b/test/testsuite/broadcasting.jl @@ -1,6 +1,7 @@ @testsuite "broadcasting" (AT, eltypes)->begin broadcasting(AT, eltypes) vec3(AT, eltypes) + unknown_wrapper(AT, eltypes) @testset "type instabilities" begin f(x) = x ? 1.0 : 0 @@ -205,3 +206,22 @@ function vec3(AT, eltypes) @test all(map((a,b)-> all((1,2,3) .≈ (1,2,3)), Array(res2), res2c)) end end + +struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} + data::P +end +Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...] +Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...) +Base.size(A::WrapArray) = size(A.data) +Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P) +function unknown_wrapper(AT, eltypes) + @views for ET in eltypes + A = AT(randn(ET, 10, 10)) + WA = WrapArray(A) + @test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A) + @test Array(WA .+ A[:,1]) == Array(A .+ A[:,1]) + @test Array(WA .+ A[1,:]) == Array(A .+ A[1,:]) + WA .= ET(1) # test for dispatch with dest's BroadcastStyle. + @test all(isequal(ET(1)), Array(A)) + end +end \ No newline at end of file From 6bf63266131937835227a0d75abae4b03981d7bb Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 21 Feb 2022 18:23:11 +0800 Subject: [PATCH 3/3] Test fix and add more comments. --- test/testsuite/broadcasting.jl | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl index fbb02683..5eda7ae5 100644 --- a/test/testsuite/broadcasting.jl +++ b/test/testsuite/broadcasting.jl @@ -207,21 +207,33 @@ function vec3(AT, eltypes) end end +# A help struct to test style-based broadcast dispatch with unknown array wrapper. +# `WrapArray(A)` behaves like `A` during broadcast. But its not a `BroadcastGPUArray`. struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} data::P end Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...] Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...) Base.size(A::WrapArray) = size(A.data) +# For kernal support +Adapt.adapt_structure(to, s::WrapArray) = WrapArray(Adapt.adapt(to, s.data)) +# For broadcast support +GPUArrays.backend(::Type{WrapArray{T,N,P}}) where {T,N,P} = GPUArrays.backend(P) Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P) + function unknown_wrapper(AT, eltypes) - @views for ET in eltypes - A = AT(randn(ET, 10, 10)) - WA = WrapArray(A) - @test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A) - @test Array(WA .+ A[:,1]) == Array(A .+ A[:,1]) - @test Array(WA .+ A[1,:]) == Array(A .+ A[1,:]) - WA .= ET(1) # test for dispatch with dest's BroadcastStyle. - @test all(isequal(ET(1)), Array(A)) + for ET in eltypes + @views @testset "unknown wrapper $ET" begin + A = AT(rand(ET, 10, 10)) + WA = WrapArray(A) + # test for dispatch with src's BroadcastStyle. + @test Array(WA .+ ET(1)) == Array(A .+ ET(1)) + @test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A) + @test Array(WA .+ A[:,1]) == Array(A .+ A[:,1]) + @test Array(WA .+ A[1,:]) == Array(A .+ A[1,:]) + # test for dispatch with dest's BroadcastStyle. + WA .= ET(1) + @test all(isequal(ET(1)), Array(A)) + end end -end \ No newline at end of file +end