diff --git a/docs/src/basics.md b/docs/src/basics.md index 944c8f7..eaf60ea 100644 --- a/docs/src/basics.md +++ b/docs/src/basics.md @@ -345,6 +345,23 @@ true ``` This one can also be done as `reinterpret(reshape, Tri{Int64}, M)`. +But what would be smarter in the general case is to do one splat, not many: + +```julia-repl +julia> Tri.(eachrow(M)...) +4-element Vector{Tri{Int64}}: + Tri{Int64}(1, 2, 3) + Tri{Int64}(4, 5, 6) + Tri{Int64}(7, 8, 9) + Tri{Int64}(10, 11, 12) + +julia> @btime Base.splat(tuple).(eachcol(m)) setup=(m=rand(4,100)); + 38.041 μs (1411 allocations: 48.33 KiB) + +julia> @btime tuple.(eachrow(m)...) setup=(m=rand(4,100)); + 824.256 ns (12 allocations: 4.06 KiB) +``` + ## Arrays of functions Besides arrays of numbers (and arrays of arrays) you can also broadcast an array of functions, diff --git a/src/macro.jl b/src/macro.jl index 249956c..f87ab09 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -520,13 +520,6 @@ function readycast(ex, target, store::NamedTuple, call::CallInfo) # and arrays of functions, using apply: @capture(ex, funs_[ijk__](args__) ) && return :( Core._apply($funs[$(ijk...)], $(args...) ) ) - # splats - @capture(ex, fun_(pre__, arg_...)) && containsindexing(arg) && begin - @gensym splat ys - xs = [gensym(Symbol(:x, i)) for i in 1:length(pre)] - push!(store.main, :( local $splat($(xs...), $ys) = $fun($(xs...), $ys...) )) - return :( $splat($(pre...), $arg) ) - end # Apart from those, readycast acts only on lone tensors: @capture(ex, A_[ijk__]) || return ex @@ -627,11 +620,13 @@ end """ recursemacro(@reduce sum(i) A[i,j]) -> G[j] -Walks itself over RHS to look for `@reduce ...`, and replace with result, +Walks itself over RHS, originally to look for `@reduce ...`, and replace with result, pushing calculation steps into store. -Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`. -And to handle naked indices, `i` => `axes(M,1)[i]` but not exactly like that. +Starts from the outside and works in, which makes it useful for other things: +* Handle naked indices, `i` => `axes(M,1)[i]` but not exactly like that, stopping before this sees `A[i]`. +* Catch splats so that `f(M[:,c]...)` can become `f.(eachrow(M)...)` not `(splat(f)).(eachcol(M))`. +* Tidy all indices, including e.g. `fun(M[:,j], N[j]).same[i']`. """ function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo) @@ -658,6 +653,36 @@ function recursemacro(ex::Expr, canon, store::NamedTuple, call::CallInfo) ex = scalar ? :($name) : :($name[$(ind...)]) end + # Handle splatted slices -- walking from inside outwards would slice the wrong way. + if @capture(ex, fun_(args__)) && any(a -> @capture(a, (A_[ijk__]...)), args) && any(iscolon, ijk) + newargs = map(args) do arg + if @capture(arg, (A_[ijk__]...)) && any(iscolon, ijk) + indpost = filter(!iscolon, ijk) + if indexin(indpost, canon) == 1:length(indpost) + Aperm = A + revcode = map(i -> iscolon(i) ? :* : :(:), ijk) + else + perm = indexin(canon, ijk) + while isnothing(last(perm)) # trim nothings off end + pop!(perm) + end + indpost = canon[1:length(perm)] + revcode = vcat(map(_ -> :*, perm), fill(:(:), count(iscolon, ijk))) + for (d,i) in enumerate(ijk) # append positions of colons + iscolon(i) && push!(perm, d) + end + Aperm = :( TensorCast.transmute($A, $(Tuple(perm))) ) + end + sliced = :( TensorCast.sliceview($Aperm, ($(revcode...),)) ) + sym = maybepush(sliced, store) + :(($sym[$(indpost...)])...) + else + recursemacro(arg, canon, store, call) + end + end + return :( $fun($(newargs...)) ) + end + # Tidy up indices, A[i,j][k] will be hit on different rounds... if @capture(ex, A_[ijk__]) if !(A isa Symbol) # this check allows some tests which have c[c] etc.