Make StructArrayStyle track inputs dimension #211
src/structarray.jl
Outdated
StructArrayStyle{S,M}(::Val{N}) where {M,S<:AbstractArrayStyle{M},N} =
    StructArrayStyle{typeof(S(Val(N))),N}()
Could this maybe be written using the function StructArrayStyle(...)
syntax and be split over a couple of lines, for the sake of readability?
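For illustration, here is a minimal sketch of what the long-form version might look like. It is not taken from the PR: the struct is redefined locally so the snippet runs on its own, and the intermediate name `T` is just an assumption made for readability.

```julia
using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle

# Local stand-in for the style type from the diff:
# S is the wrapped (inner) style, N the tracked dimensionality.
struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end

function StructArrayStyle{S,M}(::Val{N}) where {M,S<:AbstractArrayStyle{M},N}
    # Ask the inner style for its N-dimensional counterpart, then rewrap it.
    T = typeof(S(Val(N)))
    return StructArrayStyle{T,N}()
end

# A 1-d wrapped style promoted to 2 dimensions.
style = StructArrayStyle{DefaultArrayStyle{1},1}(Val(2))
```

With this definition, both the wrapper and the inner style end up at the requested dimensionality, which is the behavior of the one-line definition under review.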
src/structarray.jl
Outdated
@@ -445,7 +445,11 @@ end
 # broadcast
 import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle

-struct StructArrayStyle{Style} <: AbstractArrayStyle{Any} end
+struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end
+# If `S` also track input's dimensionality, we'd better also update it.
I confess I'm not fully familiar with the broadcast machinery, but I imagine this comment could be either expanded or removed. As it stands, it is a bit cryptic (it only makes sense if one knows the diff from this PR).
Thanks for the PR, this is very helpful! I've left some minor comments about style.
The style conflict is resolved with c7a4dc8:
julia> using StaticArrays, StructArrays, CUDA
julia> a = @SArray (randn(3,3))
3×3 SMatrix{3, 3, Float64, 9} with indices SOneTo(3)×SOneTo(3):
-2.35452 -0.492535 0.401572
-0.133826 1.66839 2.41737
1.08633 -1.31099 -0.0771547
julia> b = StructArray([1;;1;;2+im])
1×3 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Complex{Int64}:
1+0im 1+0im 2+1im
julia> a .+ b
3×3 StructArray(::Matrix{Float64}, ::Matrix{Float64}) with eltype ComplexF64: # previously a 3×3 Matrix{ComplexF64}
-1.35452+0.0im 0.507465+0.0im 2.40157+1.0im
0.866174+0.0im 2.66839+0.0im 4.41737+1.0im
2.08633+0.0im -0.310988+0.0im 1.92285+1.0im
julia> a .+ cu(b) # they have a style conflict
3×3 StructArray(::Matrix{Float64}, ::Matrix{Float64}) with eltype ComplexF64:
2.037+0.0im 3.35919+0.0im 3.90445+1.0im
2.25239+0.0im 0.468944+0.0im 2.1828+1.0im
-0.506096+0.0im 1.52528+0.0im 3.47241+1.0im
julia> collect(a) .+ cu(b) # This is still done on CPU!
3×3 StructArray(::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}) with eltype ComplexF64:
2.037+0.0im 3.35919+0.0im 3.90445+1.0im
2.25239+0.0im 0.468944+0.0im 2.1828+1.0im
-0.506096+0.0im 1.52528+0.0im 3.47241+1.0im
I only test it with
This PR is more "aggressive" than the original version in #136.
Getting this to work with CUDA would certainly be great... Pinging @timholy as he worked on the initial implementation, especially because in the intermediate comments, there were attempts to create such a fallback.
src/structarray.jl
Outdated
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle{S}}) where {S}
    return copyto!(dest, convert(Broadcasted{S}, bc))
end
At present, our GPU broadcast is dispatched based on dest's type:
@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing})
Once we change it to something like:
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle})
then the unwrapped style above could activate the kernel automatically.
Of course, there is still something left to handle manually (e.g. backend(dest)), but I think this does help us reuse more code.
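To make the unwrap-and-redispatch idea concrete, here is a small CPU-only sketch. `WrappedStyle` is a hypothetical stand-in for `StructArrayStyle`, and `DefaultArrayStyle` plays the role that a GPU array style would play; nothing here touches GPUArrays or CUDA, so it only illustrates the dispatch path, not an actual kernel launch.

```julia
using Base.Broadcast: Broadcasted, AbstractArrayStyle, DefaultArrayStyle, instantiate

# Hypothetical wrapper style: S is the inner style, N the dimensionality.
struct WrappedStyle{S,N} <: AbstractArrayStyle{N} end

# Unwrap the style and re-dispatch, so that any copyto! method
# keyed on the inner style S gets a chance to run.
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:WrappedStyle{S}}) where {S}
    return copyto!(dest, convert(Broadcasted{S}, bc))
end

# Build a broadcast expression carrying the wrapper style by hand.
bc = instantiate(Broadcasted{WrappedStyle{DefaultArrayStyle{1},1}}(+, ([1, 2], [3, 4])))
dest = zeros(Int, 2)
copyto!(dest, bc)  # falls through to Base's generic copyto! for the inner style
```

If a package defined `copyto!` on `Broadcasted{<:SomeInnerStyle}` rather than on the destination type, the same unwrapping would reach that method without the wrapper package needing to know about it.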
Also fixes #189. Related: #150. Potentially related: #94. CC: @lcw (involved in both the issue and PR mentioned).
So, I do have some comments on the code itself, but I think first we should figure out exactly what we want here.
If I understand correctly, the key problem is that when computing the overall broadcast style, one should figure out the overall broadcast style
At that point the issue becomes how to make sure that the correct mechanism for building the correct destination type is triggered. Expecting things other than
Things I don't understand:
If both work, I imagine we should test them (potentially using something like JLArray for the first one). If this is not enough to get broadcast with GPU arrays to work (and we need to rethink the dispatch on
Yes, as our general
We need some change like JuliaGPU/GPUArrays.jl#295 to enable
I'll remove the 2nd/3rd commit and add a test for unstable broadcast.
Yes, I think a good action plan would be to merge now something much simpler here (essentially the first commit). Then,
Thanks, I think that's best (there were also some comments on the first commit that were addressed somewhere within the later commits). Those 2 commits can be a draft pull request to master once the bugfix is merged (so we can make sure that things work with the changes in GPUArrays).
Merged in #212, thanks again for the PR! If you want, feel free to add back the other changes against current master.
test added.
Close #185