
Custom stacking for StaticArrays #564

Merged
merged 17 commits from gd/mapred into main on Oct 10, 2024
Conversation

gdalle
Member

@gdalle gdalle commented Oct 10, 2024

Partial answer to #563

Related:

Versions

  • Bump DI to v0.6.10

DI core

  • Customize the stacking functions used to turn tuples of arrays t into Jacobian/Hessian blocks

DI extensions

StaticArrays (new extension):
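For context, the two stacking paths under discussion produce the same result on ordinary arrays; a minimal check in plain Base Julia (helper names here are hypothetical, not the DI API):

```julia
# Two equivalent ways to turn a tuple of arrays into a matrix whose
# columns are the vectorized blocks: splatted hcat, vs Base.stack
# along dims=2 (stack requires Julia >= 1.9).
stack_cols_hcat(t) = hcat(map(vec, t)...)      # hypothetical helper
stack_cols_stack(t) = stack(vec, t; dims=2)    # hypothetical helper

tm = ntuple(i -> fill(i, 2, 3), 4)             # tuple of 2x3 matrices
A = stack_cols_hcat(tm)
B = stack_cols_stack(tm)
@assert A == B && size(A) == (6, 4)            # each vec'd block is one column
```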

@codecov-commenter

codecov-commenter commented Oct 10, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 98.00%. Comparing base (3698dbe) to head (71e1f45).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #564   +/-   ##
=======================================
  Coverage   98.00%   98.00%           
=======================================
  Files         106      108    +2     
  Lines        4808     4812    +4     
=======================================
+ Hits         4712     4716    +4     
  Misses         96       96           
Flag Coverage Δ
DI 98.66% <100.00%> (+<0.01%) ⬆️
DIT 96.68% <ø> (ø)


@mcabbott
Member

This PR adds an extension in order to have these two paths,

function DI.stack_vec_col(t::NTuple{B,<:SArray}) where {B}
    return hcat(map(vec, t)...)
end

stack_vec_col(t::NTuple) = stack(vec, t; dims=2)

Is this clearly better than just always using hcat? The splat would be terrible with long vectors, but you only support a tuple anyway, so it should be free. For example:

julia> tm = ntuple(i -> fill(i,2,3), 10);

julia> @btime stack(vec, $tm; dims=1);
  142.494 ns (12 allocations: 912 bytes)

julia> @btime hcat(map(vec, $tm)...);
  101.484 ns (12 allocations: 912 bytes)

@gdalle
Member Author

gdalle commented Oct 10, 2024

It's clearly better in the case of static arrays:

julia> using StaticArrays, BenchmarkTools

julia> ts = ntuple(i -> @SMatrix(ones(2,3)), 10);

julia> @btime stack(vec, $ts; dims=2);
  311.713 ns (1 allocation: 544 bytes)

julia> @btime hcat(map(vec, $ts)...);
  7.246 ns (0 allocations: 0 bytes)
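The zero-allocation figure is plausible because the whole pipeline stays static: vec of an SMatrix is an SVector, and hcat of SVectors yields an SMatrix, so nothing ever touches the heap. A sketch of that reasoning (requires the StaticArrays package):

```julia
using StaticArrays

ts = ntuple(i -> @SMatrix(ones(2, 3)), 4)
cols = map(vec, ts)       # NTuple of SVector{6}: still stack-allocated
J = hcat(cols...)         # SMatrix{6,4}: fully static result
@assert J isa SMatrix{6,4}
```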

@gdalle
Member Author

gdalle commented Oct 10, 2024

Sorry, I read your comment the wrong way. I did some more thorough benchmarks in this issue, and it seems stack does come out on top for usual arrays?

@mcabbott
Member

Wow that's quite hard to decode. (Probably I should have used dims=2 above, slightly faster codepath, same as no dims). But indeed stack is faster than hcat on that example:

julia> tm = ntuple(i -> rand(100, 100), 10);

julia> res1 = @btime stack(vec, $tm);  # called "bad stack, function" at link
  14.250 μs (13 allocations: 781.64 KiB)

julia> res2 = @btime stack(vec, $tm, dims=2);  # version with dims as in PR 
  14.250 μs (13 allocations: 781.64 KiB)

julia> res3 = @btime hcat(map(vec, $tm)...);  # called "good stack, function" at link
  64.416 μs (13 allocations: 781.64 KiB)

julia> res1 == res2 == res3
true

Whether that's true in general I don't know, perhaps I'm somewhat surprised. And if it is, whether it's worth the complexity is your call.

Note that vec(::Matrix) allocates. If you care a lot about this operation you might also consider lazier reshape:

julia> myvec(x) = Base.ReshapedArray(x, (length(x),), ());  # not sure 3rd argument is optimal!

julia> res4 = @btime hcat(map(myvec, $tm)...);  # faster
  38.792 μs (3 allocations: 781.33 KiB)

julia> @btime stack(myvec, $tm);  # slower
  22.084 μs (3 allocations: 781.33 KiB)
  
julia> res1 == res4
true
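On the vec(::Matrix) point: the operation is copy-free in the sense that the result shares memory with the parent, but reshape still heap-allocates a fresh Array wrapper on every call, which is what the lazier ReshapedArray avoids. A quick check in Base Julia:

```julia
m = rand(3, 4)
v = vec(m)                  # same underlying data, new Array wrapper
v[1] = 99.0
@assert m[1, 1] == 99.0     # memory is shared with the parent matrix
@assert vec(m) !== vec(m)   # but each call returns a fresh wrapper object
```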

@gdalle
Member Author

gdalle commented Oct 10, 2024

Whether that's true in general I don't know, perhaps I'm somewhat surprised. And if it is, whether it's worth the complexity is your call.

In the current state of things, stack is used everywhere. This very short PR switches to hcat only for SArrays, which is definitely better. There's a question of whether we should switch everywhere, but I'll put it on the back burner for now since the answer is less obvious from the benchmarks.

By the way, any ideas on how to implement stack(...; dims=1) with cat operations? For columns it was easy because hcat turns vectors into matrices, but vcat just makes longer vectors. Would you transpose first, or do yet another trick?

Note that vec(::Matrix) allocates. If you care a lot about this operation you might also consider lazier reshape:

Wow, I didn't know that. I naively thought that a Matrix was just a vector in a trench coat, hence this would be free.
Any reason not to use reshape in your suggestion?
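For the dims=1 question above, one trick (a sketch, not what the PR does) is to reshape each block into a 1×N row first, so that vcat stacks rows the same way hcat stacks columns:

```julia
# Row-stacking via cat: reshape each block to a 1xN row, then vcat.
# Equivalent to stack(vec, t; dims=1) (stack requires Julia >= 1.9).
stack_rows(t) = vcat(map(x -> reshape(x, 1, :), t)...)   # hypothetical helper

tm = ntuple(i -> fill(i, 2, 3), 4)
@assert stack_rows(tm) == stack(vec, tm; dims=1)         # size (4, 6)
```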

@gdalle gdalle merged commit 7607ec2 into main Oct 10, 2024
43 checks passed
@gdalle gdalle deleted the gd/mapred branch October 10, 2024 16:46