Generalize StructArray's broadcast. #215

Merged 9 commits on Nov 30, 2022

Conversation

N5N3
Contributor

@N5N3 N5N3 commented Feb 20, 2022

Separated from #211.
The final goal is to enable GPU broadcast (closes #150).

GPU related:

@N5N3

This comment was marked as outdated.

@piever
Collaborator

piever commented Feb 21, 2022

This is great! I've left a few comments: some are just about coding style, others are about checking that we don't do unnecessary allocations.

The example in the comment above is very nice. Btw, does cu(a) use non-scalar getindex, or is it just displaying the result that causes that?

@N5N3
Contributor Author

N5N3 commented Feb 21, 2022

The example in the comment above is very nice. Btw, does cu(a) use non-scalar getindex, or is it just displaying the result that causes that?

It is caused by displaying the result. If I add a trailing ;, the warning disappears.

@N5N3 N5N3 force-pushed the style_conflict branch 3 times, most recently from c4e7a03 to a92bc41 Compare February 21, 2022 12:07
@mcabbott

What's the status of this work?

Perhaps overlapping, what JuliaDiff/ChainRules.jl#644 would like is broadcasts like this:

julia> function _tuplecast(f::F, args...) where {F}
           bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
           return StructArrays.components(StructArray(bc))
       end;

julia> _tuplecast(tuple, [1,2,3], [3,4,5])
([1, 2, 3], [3, 4, 5])

julia> _tuplecast(tuple, cu([1,2,3]), cu([3,4,5]))
ERROR: Scalar indexing is disallowed.
...
  [9] iterate
    @ ./broadcast.jl:261 [inlined]
 [10] iterate(bc::Base.Broadcast.Broadcasted{JLArrays.JLArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(tuple), Tuple{JLArray{Int64, 1}, JLArray{Int64, 1}}})
    @ Base.Broadcast ./broadcast.jl:255
 [11] collect_structarray(itr::Base.Broadcast.Broadcasted{JLArrays.JLArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(tuple), Tuple{JLArray{Int64, 1}, JLArray{Int64, 1}}}; initializer::StructArrays.StructArrayInitializer{typeof(StructArrays.alwaysfalse), typeof(StructArrays.arrayof)})
    @ StructArrays ~/.julia/packages/StructArrays/xGn32/src/collect.jl:39

There is now a GPUArraysCore.jl if this helps.

@N5N3
Contributor Author

N5N3 commented Jul 15, 2022

I haven't paid attention to our GPU ecosystem for a long time.
GPUArraysCore.jl should be a better place than Adapt.jl to hold GPUArrays.backend.
With that I think we can finish this PR and _tuplecast could become something like:

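function _tuplecast(f::F, args...) where {F}
    bc = Broadcast.instantiate(Broadcast.broadcasted(f, args...))
    # Re-tag the broadcast with StructArrayStyle, wrapping its original style type.
    S = typeof(Broadcast.BroadcastStyle(typeof(bc)))
    style = StructArrays.StructArrayStyle{S,ndims(bc)}
    return StructArrays.components(copy(convert(Broadcast.Broadcasted{style}, bc)))
end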
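
The re-tagging step in the sketch above relies on Base's `convert(Broadcasted{NewStyle}, bc)`, which swaps the style parameter of a `Broadcasted` without touching its function, arguments, or axes. A minimal Base-only illustration follows. `TagStyle` and `style_of` are hypothetical names standing in for `StructArrayStyle` and are not part of StructArrays; the `similar` method here allocates a plain `Array`, whereas a real implementation would presumably build a `StructArray`.

```julia
using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, instantiate

# Hypothetical stand-in for StructArrayStyle: wraps an inner style type S.
struct TagStyle{S,N} <: AbstractArrayStyle{N} end

# Destination allocation for a TagStyle broadcast (assumption: plain Array).
Base.similar(bc::Broadcasted{<:TagStyle}, ::Type{T}) where {T} =
    similar(Array{T}, axes(bc))

# The style is the first type parameter of Broadcasted.
style_of(::Broadcasted{S}) where {S} = S

bc = instantiate(broadcasted(+, [1, 2, 3], [3, 4, 5]))

# Same lazy computation, now dispatching on TagStyle.
retagged = convert(Broadcasted{TagStyle{style_of(bc),ndims(bc)}}, bc)

# copy materializes the broadcast through the TagStyle `similar` above.
result = copy(retagged)
```

Because only the style changes, `copy(retagged)` computes the same values as `copy(bc)`; the style merely decides which `similar`/`copyto!` methods allocate and fill the output.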

@N5N3
Contributor Author

N5N3 commented Jul 27, 2022

Now the GPU patch is pending JuliaGPU/GPUArrays.jl#420.
Because this PR adds a GPUArraysCore.jl dependency, compatibility with Julia < 1.6 will break once it is merged.
The StaticArray related overload needs StaticArrays >= 1.4.2, with which we have:

julia> a = @SMatrix [float(i) for i in 1:4, j in 1:4];

julia> b = @SMatrix [float(j) for i in 1:4, j in 1:4];

julia> s = StructArray{ComplexF64}((a , b))
4×4 StructArray(::SMatrix{4, 4, Float64, 16}, ::SMatrix{4, 4, Float64, 16}) with eltype ComplexF64 with indices SOneTo(4)×SOneTo(4):
 1.0+1.0im  1.0+2.0im  1.0+3.0im  1.0+4.0im
 2.0+1.0im  2.0+2.0im  2.0+3.0im  2.0+4.0im
 3.0+1.0im  3.0+2.0im  3.0+3.0im  3.0+4.0im
 4.0+1.0im  4.0+2.0im  4.0+3.0im  4.0+4.0im

julia> abs.(s)
4×4 SMatrix{4, 4, Float64, 16} with indices SOneTo(4)×SOneTo(4):
 1.41421  2.23607  3.16228  4.12311
 2.23607  2.82843  3.60555  4.47214
 3.16228  3.60555  4.24264  5.0
 4.12311  4.47214  5.0      5.65685

julia> log.(s) # Yes, this returns a StructArray now!
4×4 StructArray(::SMatrix{4, 4, Float64, 16}, ::SMatrix{4, 4, Float64, 16}) with eltype ComplexF64 with indices SOneTo(4)×SOneTo(4):
 0.346574+0.785398im  0.804719+1.10715im   1.15129+1.24905im   1.41661+1.32582im
 0.804719+0.463648im   1.03972+0.785398im  1.28247+0.982794im  1.49787+1.10715im
  1.15129+0.321751im   1.28247+0.588003im  1.44519+0.785398im  1.60944+0.927295im
  1.41661+0.244979im   1.49787+0.463648im  1.60944+0.643501im  1.73287+0.785398im

I think @mcabbott might be happy with this feature (as we might pass a StaticArray into _tuplecast?).

@N5N3
Contributor Author

N5N3 commented Aug 21, 2022

We have now dropped the dependency on StaticArrays.jl. Broadcast on GPU should be ready.
Static broadcast still needs some code to be moved (JuliaArrays/StaticArraysCore.jl#8 and JuliaArrays/StaticArraysCore.jl#5).

@N5N3 N5N3 force-pushed the style_conflict branch 2 times, most recently from 4743976 to 2644c24 Compare August 24, 2022 09:14
@N5N3 N5N3 closed this Aug 24, 2022
@N5N3 N5N3 reopened this Aug 24, 2022
@N5N3 N5N3 marked this pull request as ready for review August 24, 2022 11:18
@N5N3
Contributor Author

N5N3 commented Aug 24, 2022

@piever Should be ready to review.

The only concern is SizedArray: the current implementation causes an extra allocation, as we have to create a temporary StaticArray if we don't want to copy a lot of code from StaticArrays. (I don't think we'll have a public API in StaticArraysCore that returns the broadcasted tuple.)

@piever
Collaborator

piever commented Sep 2, 2022

Thanks for pushing for all the downstream changes, I'll try to find time to review in the next few days.

The only concern is SizedArray: the current implementation causes an extra allocation, as we have to create a temporary StaticArray if we don't want to copy a lot of code from StaticArrays. (I don't think we'll have a public API in StaticArraysCore that returns the broadcasted tuple.)

I confess I'm not sure I 100% understand what happens there. When exactly does the extra allocation happen for SizedArray? I would imagine you only allocate the output once and modify it in-place (AFAIU, SizedArrays are mutable).

@N5N3
Contributor Author

N5N3 commented Sep 4, 2022

Yes, SizedArrays are mutable, but StaticArrays's broadcast still generates a tuple containing all broadcasted results, then converts this tuple to a SizedArray.

At present, since we dropped the dependency, we have to go through this Tuple -> StaticArray -> Tuple transformation.
If this StaticArray is an S/MArray or FieldArray, the compiler should be able to optimize it away.
But that's not true for SizedArray.

Once we have more tools in StaticArraysCore.jl, we can try to implement static broadcast ourselves to avoid this problem.

@piever
Collaborator

piever commented Oct 14, 2022

Sorry it took me a long time to review this!

It mostly looks good. The main worry is whether there is a performance hit in the general case due to the overload of Broadcast._axes here.

Performance issues on things that didn't work before (like the case with StaticArrays columns) are not a showstopper, but at least we must test that there is no performance regression in the case of plain Array columns.

@piever
Collaborator

piever commented Oct 14, 2022

One last thing, we should also check the interaction of this with #249 (support for nested broadcast). It should be a simple rebase, as #249 is a tiny change.

@N5N3 N5N3 force-pushed the style_conflict branch 2 times, most recently from 44f18d5 to cd32b53 Compare October 15, 2022 02:52
@N5N3
Contributor Author

N5N3 commented Oct 15, 2022

It mostly looks good. The main worry is whether there is a performance hit in the general case due to the overload of Broadcast._axes here.

This overload is limited to StaticArray, so I believe there is no performance impact on plain Array.

One last thing, we should also check the interaction of this with #249 (support for nested broadcast). It should be a simple rebase, as #249 is a tiny change.

I have avoided a possible nested StructArrayStyle in 47f122c with

combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle

And #249 has been pulled in to test it.
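
The idea behind that one-liner, unwrapping an already-wrapped style before wrapping again so that wrappers never nest, can be sketched with Base alone. `WrapStyle`, `unwrap`, and `wrap` are hypothetical names used only for illustration; they are not the StructArrays internals.

```julia
using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, DefaultArrayStyle

# Hypothetical wrapper style, standing in for StructArrayStyle{S,N}.
struct WrapStyle{S,N} <: AbstractArrayStyle{N} end

# Strip one wrapper layer: a wrapped style yields an instance of its inner
# style; any other style passes through unchanged.
unwrap(::WrapStyle{S}) where {S} = S()
unwrap(s::BroadcastStyle) = s

# Wrap a style, unwrapping first so the result is never WrapStyle{WrapStyle{...}}.
wrap(s::AbstractArrayStyle{N}) where {N} = WrapStyle{typeof(unwrap(s)),N}()

base = DefaultArrayStyle{1}()
once = wrap(base)    # WrapStyle{DefaultArrayStyle{1},1}()
twice = wrap(once)   # identical to `once`: the inner wrapper was stripped
```

Wrapping is idempotent under this scheme, which is the property the nested-broadcast case needs.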

N5N3 and others added 5 commits October 15, 2022 11:47
We first call broadcast from `StaticArrays`, then split the output.
This should have no extra runtime overhead, but some type info might be missing because the eltype changes. I think there's no better way, as we don't want to depend on the full `StaticArrays`.

We don't overload `Size` and `similar_type` at present,
as they are only used for `broadcast`.
With this, we can move much less code to `StaticArraysCore`.

The only downside is that SizedArray would be allocated twice. That's not ideal, but we can't do any better if we don't depend on StaticArrays or copy a lot of code from there.
@piever
Collaborator

piever commented Nov 5, 2022

Sorry for the delay between reviews!

I confess some of the code is a bit over my head, but (other than the small comments above) it looks good to me!

1. Update Project.toml.
2. test `backend`'s inferability.

Co-Authored-By: Pietro Vertechi <[email protected]>
test/runtests.jl Outdated
Comment on lines 1164 to 1169
try
    d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
    @test typeof(a .+ b .- c) == typeof(d)
catch
    @test_throws MethodError a .+ b .- c
end

This had escaped me before, but I'm wondering: could it be possible to be explicit here on which it is (correct result or method error) based on the input types?

Ideally one would want to explicitly test what one is getting, so I would suggest removing the helper function _test_similar and just writing something like

if s2 in (s, s′, s″) && (s1 in (s, s′, s″) || s3 in (s, s′, s″))
    # test method error
else
    # test correct result
end

in the loop body.

(I'm not sure whether that's the correct criterion.)
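
As a generic illustration of the pattern suggested here (decide the expected outcome from the inputs, then assert exactly that outcome rather than wrapping the test in try/catch), a self-contained sketch using only the `Test` standard library might look like this. `combine` and the `cases` table are hypothetical and unrelated to the StructArrays test suite.

```julia
using Test

# Hypothetical operation: defined for Int inputs only, so any other
# argument type raises a MethodError.
combine(x::Int, y::Int) = x + y

# Each case states up front whether it is expected to throw.
cases = [
    (args = (1, 2), throws = false, expected = 3),
    (args = (1, "2"), throws = true, expected = nothing),
]

@testset "explicit outcome per input" begin
    for c in cases
        if c.throws
            # The failure mode is asserted, not silently tolerated.
            @test_throws MethodError combine(c.args...)
        else
            @test combine(c.args...) == c.expected
        end
    end
end
```

With try/catch, a regression that turns a working case into a MethodError would still pass; with an explicit branch it fails.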

only test each dest style once.
@piever
Collaborator

piever commented Nov 30, 2022

Perfect, thanks for bringing this over the finish line!

@piever piever merged commit 1afecf4 into JuliaArrays:master Nov 30, 2022
@N5N3 N5N3 deleted the style_conflict branch December 2, 2022 12:59
Development

Successfully merging this pull request may close these issues.

Add support of GPU Broadcast with StructArray LHS?
3 participants