inference failure for broadcast implementations that use flatten (eg SparseArrays) #27988

Closed
marius311 opened this issue Jul 8, 2018 · 20 comments · Fixed by #43322
Labels: broadcast (Applying a function over a collection), performance (Must go faster)

Comments

@marius311
Contributor

marius311 commented Jul 8, 2018

I've been using the new broadcasting API to implement broadcasting for some custom types (which btw, the new API is great and massively simplified my code, thanks!). I did however run into the following surprising inference failure which I've reduced down to the following code:

import Base.Broadcast: BroadcastStyle, materialize, broadcastable
using Base.Broadcast: Broadcasted, Style, flatten
using Test

struct NumWrapper{T}
    data::T
end

broadcastable(n::NumWrapper) = n
BroadcastStyle(::Type{N}) where {N<:NumWrapper} = Style{N}()
broadcast_data(n::NumWrapper) = (n.data,)
function materialize(bc::Broadcasted{Style{N}}) where {N<:NumWrapper}
    flat_bc = flatten(bc)
    N(broadcast.(flat_bc.f, broadcast_data.(flat_bc.args)...)...)
end

foo(a,b) = @. a * a + (b * a) * b
bar(a,b) = @. a * a + b * a * b

@inferred foo(NumWrapper(1), NumWrapper(2)) #inferred as Any
@inferred bar(NumWrapper(1), NumWrapper(2)) #inferred correctly

As you can see, each NumWrapper just holds a number, and broadcasting over e.g. a::NumWrapper .+ b::NumWrapper becomes NumWrapper(broadcast.(+,(a.data,),(b.data,))...) = NumWrapper((a.data .+ b.data,)...). Note that I need the .data wrapped in a tuple in my real code because there NumWrapper has multiple fields; this is also crucial to triggering the bug, although in the simple example above it probably seems unnecessary. In any case, you can see that the seemingly unimportant addition of parentheses around (b*a) spoils type stability. This is on commit 656d587.

Beyond a possible solution, I'm also curious whether there's a workaround or a better way to code this up. I can't say I'm 100% sure I've used the new API as intended (in particular, I've used flatten, and the docs make it seem like this shouldn't usually be necessary). In any case, hope this helps!

EDIT: Fixed a small mistake in the text describing what a::NumWrapper .+ b::NumWrapper became.
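A minimal sketch of what flatten produces for the expression in foo, using plain Ints in place of NumWrapper for brevity (an assumption; the nesting of the lazy Broadcasted tree is the same). Base.Broadcast.flatten is an internal, unexported function, so its exact behavior may vary between Julia versions.

```julia
using Base.Broadcast: broadcasted, flatten

a, b = 1, 2

# Lazy (unmaterialized) form of `@. a * a + (b * a) * b`
bc = broadcasted(+, broadcasted(*, a, a), broadcasted(*, broadcasted(*, b, a), b))

# flatten gathers the leaf arguments, left to right, into one tuple and wraps
# the nested calls into a single function of that flat argument list
fl = flatten(bc)

fl.args            # (1, 1, 2, 1, 2)
fl.f(fl.args...)   # 1*1 + (2*1)*2 == 5
```

The materialize above then broadcasts flat_bc.f over the per-field data, so if inference loses track of the flattened function's return type, the result type is lost with it.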

@StefanKarpinski
Member

Excellent bug report! @Keno, would be great if you can look at the inference failure. @mbauman, can you give advice on the broadcast API usage?

@timholy
Member

timholy commented Jul 8, 2018

Other than to say that I suspect you're working harder than you need to, this isn't quite enough to go on. What are the rules for indexing a NumWrapper? What rules of arithmetic do they follow? The main design goal of the new API for extending broadcasting (#23939) was to be "orthogonal" to both indexing and arithmetic, so that you only have to specify things that are specific to broadcasting (like how the indexing should work when combined with other containers).

So, you probably need some indexing rules, maybe like those in number.jl. And you should make sure that arithmetic works on its own:

julia> a, b = NumWrapper(1), NumWrapper(2)
(NumWrapper{Int64}(1), NumWrapper{Int64}(2))

julia> a*a + (b*a)*b
ERROR: MethodError: no method matching *(::NumWrapper{Int64}, ::NumWrapper{Int64})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:502
Stacktrace:
 [1] top-level scope at none:0

The absence of these rules may be why you have non-inferrability.
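A minimal sketch of the standalone arithmetic being asked about, assuming NumWrapper should simply wrap the result of the corresponding operation on .data (these two methods are illustrative and not from the original code):

```julia
struct NumWrapper{T}
    data::T
end

# Ordinary (non-broadcast) arithmetic, defined independently of broadcasting
Base.:*(x::NumWrapper, y::NumWrapper) = NumWrapper(x.data * y.data)
Base.:+(x::NumWrapper, y::NumWrapper) = NumWrapper(x.data + y.data)

a, b = NumWrapper(1), NumWrapper(2)
r = a*a + (b*a)*b   # wraps 1*1 + (2*1)*2 == 5
```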

If they basically act like scalars, then it's just possible that you might need only one definition to support broadcasting:

Base.BroadcastStyle(::Type{<:NumWrapper}) = Base.Broadcast.DefaultArrayStyle{0}()

That declares that they act like 0-dimensional arrays, and I think the normal rules will work to preserve the "eltype." If you do have to give them a specific style, then you're going to have to define binary rules to specify how they combine with other containers.

@marius311
Contributor Author

marius311 commented Jul 8, 2018

Thanks, yeah, I think you're right; sorry, the above example has too much stripped out for you to be able to give really meaningful advice.

To say a bit more, the real purpose is to have a structure (what I called NumWrapper above and I'll call Field below) for which broadcast operations are "forwarded" down to the individual fields of the structure. So if you had, e.g.

struct Foo <: Field
    a::Matrix
    b::Matrix
end

struct Bar <: Field
    c::Matrix
end

then you would have basically the following equivalency,

x::Foo .+ y::Foo .+ z::Foo = Foo(x.a .+ y.a .+ z.a, x.b .+ y.b .+ z.b)

Note that this isn't quite what e.g. tuple broadcasting does, because that doesn't "forward" the broadcast to the individual elements, so it can be inefficient if the elements are things like large matrices, as in my case.
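To make the intended semantics concrete, here is a sketch that spells out the equivalence above as an ordinary function rather than through the broadcast machinery. The helper fieldwise is hypothetical, and Foo is fixed to two Float64 matrices for simplicity:

```julia
# Hypothetical two-field type, fixed to Float64 matrices for simplicity
struct Foo
    a::Matrix{Float64}
    b::Matrix{Float64}
end

# Hypothetical helper: apply f elementwise to each field separately, across
# several Foos, so fieldwise(+, x, y, z) == Foo(x.a .+ y.a .+ z.a, x.b .+ y.b .+ z.b)
fieldwise(f, xs::Foo...) =
    Foo(broadcast(f, map(v -> v.a, xs)...), broadcast(f, map(v -> v.b, xs)...))

x = Foo(ones(2, 2), fill(2.0, 1, 2))
s = fieldwise(+, x, x, x)   # s.a is all 3.0, s.b is all 6.0
```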

Similarly, you would have,

x::Foo .+ w::Bar = Foo(x.a .+ w.c, x.b .+ w.c)

In the real case I do indeed have a set of binary broadcast rules, for example to decide that the above result is a Foo (plus some others, for reasons beyond what I've shown here).

I also have a materialize function which is in fact identical to the one in my first post, i.e. it calls flatten and then does a broadcasted broadcast call over the fields in the data structure. That's really the piece I'm unsure is the right way to do this. Any advice on this is much appreciated (but by no means expected! really the main purpose here was the bug report).

@timholy
Member

timholy commented Jul 9, 2018

This helps. There's still some question in my mind about whether you should be expressing those operations as broadcasting .+ or just ordinary +. Does Field support indexing, i.e., can I ask for (x::Foo)[3,5]? Can I set an entry with x[3,5] = (1.7, 3.5) or similar? If the answer is "no", then this might not be broadcasting; you might just be defining your own arithmetic operators that you happen to write in broadcasted notation. And if they don't support indexing, then in my opinion this is an abuse of broadcasting. (It's fine to use broadcasting in the internal implementation of your arithmetic rules, but if Foo doesn't support indexing then summing two Foos isn't a broadcasting operation.)

If they do support indexing, then you might only need something along the lines of

# These act like two-dimensional arrays, and should take precedence over e.g., Array
struct FooStyle <: Broadcast.AbstractArrayStyle{2} end
struct BarStyle <: Broadcast.AbstractArrayStyle{2} end
Base.BroadcastStyle(::Type{<:Foo}) = FooStyle()
Base.BroadcastStyle(::Type{<:Bar}) = BarStyle()

# FooStyle "beats" BarStyle: Foo.+Foo->Foo, Foo.+Bar->Foo, Bar.+Bar->Bar
Base.BroadcastStyle(::FooStyle, ::BarStyle) = FooStyle()

# Teach Julia how to allocate the output container
Base.similar(bc::Broadcasted{FooStyle}, ::Type{ElType}) = Foo{ElType}(undef, size(bc))
Base.similar(bc::Broadcasted{BarStyle}, ::Type{ElType}) = Bar{ElType}(undef, size(bc))

Julia's internals should handle the rest automatically for you 😄.
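For a concrete, runnable version of the same pattern, here is a sketch with a single hypothetical wrapper type MyVec backed by one Vector, following the "Customizing broadcasting" section of the Julia manual. The MyVecStyle(::Val{1}) method is needed because combining a custom AbstractArrayStyle with a scalar's DefaultArrayStyle{0} goes through typeof(style)(Val(N)):

```julia
# Hypothetical vector wrapper that should stay a MyVec under broadcasting
struct MyVec{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::MyVec) = size(v.data)
Base.IndexStyle(::Type{<:MyVec}) = IndexLinear()
Base.getindex(v::MyVec, i::Int) = v.data[i]
Base.setindex!(v::MyVec, x, i::Int) = (v.data[i] = x; v)

# A 1-d style that takes precedence over the default array styles
struct MyVecStyle <: Broadcast.AbstractArrayStyle{1} end
MyVecStyle(::Val{1}) = MyVecStyle()
Base.BroadcastStyle(::Type{<:MyVec}) = MyVecStyle()

# Teach broadcasting how to allocate the output container
Base.similar(bc::Broadcast.Broadcasted{MyVecStyle}, ::Type{T}) where {T} =
    MyVec(similar(Array{T}, axes(bc)))

x = MyVec([1, 2, 3])
y = x .+ 10 .* x   # stays a MyVec, with data [11, 22, 33]
```

Only Val{1} is handled here, so mixing a MyVec with a matrix would raise a MethodError rather than silently picking a shape.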

@mbauman
Member

mbauman commented Jul 9, 2018

I would agree with Tim on the broadcasting design, although I must confess that this is a rather clever abuse that allows forwarding of arbitrary functions. You may run into trouble if your type hits another that overrides broadcast styles, though.

As far as the inference failure, it's stemming from flatten — we recursively construct a set of closures to do the flattening, and that's where inference is losing the trail. It's a rather crazy computation, but the part that's failing is the relatively simpler flattening of argument lists in cat_nested. I did a bit of cursory poking at it, but throwing a few extra @inlines and ::Vararg{Any,N} specializations on it didn't seem to solve the underlying issue. Here's a simpler MWE:

julia> bc =  Broadcast.Broadcasted(+, (Broadcast.Broadcasted(*, (1, 2)), Broadcast.Broadcasted(*, (Broadcast.Broadcasted(*, (3, 4)), 5))));

julia> Broadcast.cat_nested(x->x.args, bc)
(1, 2, 3, 4, 5)

julia> @code_warntype Broadcast.cat_nested(x->x.args, bc)
Body::Tuple{Int64,Int64,Vararg{Any,N} where N}

@marius311
Contributor Author

Thanks for the replies. So my Fields do represent abstract vectors, the vector components being all the entries in the matrices, e.g. the vector representation of f::Foo would be [f.a[:]; f.b[:]] (the reason to store them as matrices instead of actual vectors is that there are other operations I do on them that make more sense in terms of the separate matrices). This means they could have getindex/setindex! defined and I could implement broadcasting the way @timholy suggested, but in those methods I would have to do some index arithmetic to map into the vector representation, and my guess was that this would make broadcasting slower. Instead I went about it in the "top-down" fashion above and forwarded the broadcast directly to the matrices, which I figured wouldn't incur the performance hit (I can confirm this on 0.6, where it currently works) and also felt more intuitive. But maybe there's a way to write the index arithmetic without incurring a performance hit? I admit I have not tried.
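One way to try the indexing approach is sketched below, assuming a hypothetical FooVec with exactly two matrix fields laid out as [f.a[:]; f.b[:]]. The index arithmetic is a single branch in getindex/setindex!; this has not been benchmarked against the forwarding approach, so it does not settle the performance question by itself:

```julia
# Hypothetical Field with two matrix components, exposed as one flat vector
# whose entries are [f.a[:]; f.b[:]]
struct FooVec{T} <: AbstractVector{T}
    a::Matrix{T}
    b::Matrix{T}
end

Base.size(f::FooVec) = (length(f.a) + length(f.b),)
Base.IndexStyle(::Type{<:FooVec}) = IndexLinear()

# The index arithmetic: linear index i falls in f.a first, then in f.b
function Base.getindex(f::FooVec, i::Int)
    na = length(f.a)
    return i <= na ? f.a[i] : f.b[i - na]
end

function Base.setindex!(f::FooVec, v, i::Int)
    na = length(f.a)
    if i <= na
        f.a[i] = v
    else
        f.b[i - na] = v
    end
    return f
end

x = FooVec([1.0 2.0; 3.0 4.0], [5.0 6.0])
y = FooVec(ones(2, 2), ones(1, 2))
z = FooVec(zeros(2, 2), zeros(1, 2))

# In-place broadcasting writes through setindex!, updating both fields of z
z .= x .+ 2 .* y
```

Without a custom BroadcastStyle, an out-of-place x .+ y would return a plain Vector rather than a FooVec; the in-place form above keeps the existing fields.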

@StefanKarpinski
Member

StefanKarpinski commented Jul 10, 2018

I guess the question I have then is: is this a performance bug or not? It sounds like this may be a case where broadcasting is inappropriate but this is an optimization we want to work anyway?

@mbauman
Member

mbauman commented Jul 12, 2018

This is a real performance issue, and it'll have an effect on any broadcasting implementation that uses flatten — this notably includes SparseArrays. Here's the minimal example:

using Base: tail
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(cat_nested(t[1], tail(t)), rest)
cat_nested(t::Tuple{}, tail) = tail
t = ((1, 2), ((3, 4), 5))
@code_warntype cat_nested(t, ())

This worked in Julia 0.6, but I imagine that it broke as part of @vtjnash's work in limiting non-terminating recursive inference. Is there a workaround here? Or could we support this? The goal is simply to flatten arbitrarily nested tuples — in this case, just return (1,2,3,4,5).

mbauman added the performance (Must go faster) and broadcast (Applying a function over a collection) labels Jul 12, 2018
@mbauman
Member

mbauman commented Jul 12, 2018

Ah, here's the workaround: don't be doubly-recursive within a single method:

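using Base: tail
# Leading element is not a tuple: emit it and keep going
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
# Leading element is a tuple: descend into it and push the remaining
# elements onto `rest` (one recursive call here, not two nested ones)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
# Current tuple exhausted: continue with the pending elements
cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
# Nothing left anywhere
cat_nested(t::Tuple{}, tail::Tuple{}) = ()
t = ((1, 2), ((3, 4), 5))
@code_warntype cat_nested(t, ())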

@mbauman
Member

mbauman commented Jul 13, 2018

Ok, this is bizarre. I'm getting a state-dependent order-of-compilation difference in inference:

$ ./julia
               _
   _       _ _(_)_     |  A fresh approach to technical computing
  (_)     | (_) (_)    |  Documentation: https://docs.julialang.org
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 0.7.0-beta.283 (2018-07-12 22:44 UTC)
 _/ |\__'_|_|_|\__'_|  |  Commit 98061abb5a* (0 days old master)
|__/                   |  x86_64-linux-gnu

julia> using Base: tail
       cat_nested(t::Tuple) = cat_nested(t, ())
       cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
       cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
       cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
       cat_nested(t::Tuple{}, tail::Tuple{}) = ()
       t = ((1, 2), ((3, 4), 5))
((1, 2), ((3, 4), 5))

julia> @code_warntype cat_nested(t, ())
Body::NTuple{5,Int64}
4 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}            │╻       getindex
  │         getfield(%%t, 1)                                           │╻       tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}          ││
  │   %4  = Base.getfield(%1, 1, true)::Int64                          ││╻       getindex
  │         getfield(%1, 1)                                            ││╻       tail
  │   %6  = getfield(%1, 2)::Int64                                     │││
  │   %7  = Base.getfield(%3, 1, true)::Tuple{Int64,Int64}             │││╻╷╷╷    cat_nested
  │         getfield(%3, 1)                                            ││││╻       cat_nested
  │   %9  = getfield(%3, 2)::Int64                                     │││││┃│      cat_nested
  │   %10 = Base.getfield(%7, 1, true)::Int64                          ││││││╻       cat_nested
  │         getfield(%7, 1)                                            │││││││╻       tail
  │   %12 = getfield(%7, 2)::Int64                                     ││││││││
  │   %13 = Core.tuple(%4, %6, %10, %12, %9)::NTuple{5,Int64}          ││
  └──       return %13                                                 │

julia> @code_warntype cat_nested(t)
Body::NTuple{5,Int64}
2 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}          │╻╷       cat_nested
  │         getfield(%%t, 1)                                         ││╻        tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}        │││
  │   %4  = Base.getfield(%1, 1, true)::Int64                        │││╻        getindex
  │         getfield(%1, 1)                                          │││╻        tail
  │   %6  = getfield(%1, 2)::Int64                                   ││││
  │   %7  = Base.getfield(%3, 1, true)::Tuple{Int64,Int64}           ││││╻╷╷╷     cat_nested
  │         getfield(%3, 1)                                          │││││╻        cat_nested
  │   %9  = getfield(%3, 2)::Int64                                   ││││││┃│       cat_nested
  │   %10 = Base.getfield(%7, 1, true)::Int64                        │││││││╻        cat_nested
  │         getfield(%7, 1)                                          ││││││││╻        tail
  │   %12 = getfield(%7, 2)::Int64                                   │││││││││
  │   %13 = Core.tuple(%4, %6, %10, %12, %9)::NTuple{5,Int64}        │││
  └──       return %13                                               │

But if I restart my session and call those two methods in a different order:

$ ./julia -q
julia> using Base: tail
       cat_nested(t::Tuple) = cat_nested(t, ())
       cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
       cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
       cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
       cat_nested(t::Tuple{}, tail::Tuple{}) = ()
       t = ((1, 2), ((3, 4), 5))
((1, 2), ((3, 4), 5))

julia> @code_warntype cat_nested(t)
Body::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
2 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}                     │╻╷  cat_nested
  │         getfield(%%t, 1)                                                    ││╻   tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}                   │││
  │   %4  = Core.tuple(%3)::Tuple{Tuple{Tuple{Int64,Int64},Int64}}              ││
  │   %5  = Base.getfield(%1, 1, true)::Int64                                   │││╻   getindex
  │   %6  = Core.tuple(%5)::Tuple{Int64}                                        │││
  │         getfield(%1, 1)                                                     │││╻   tail
  │   %8  = getfield(%1, 2)::Int64                                              ││││
  │   %9  = Core.tuple(%8)::Tuple{Int64}                                        ││││
  │   %10 = invoke Main.cat_nested(%9::Tuple{Int64}, %4::Tuple{Tuple{Tuple{Int64,Int64},Int64}})::Tuple{Int64,Int64,Int64,Vararg{Any,N} where N}
  │   %11 = Core._apply(Core.tuple, %6, %10)::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
  └──       return %11                                                          │

julia> @code_warntype cat_nested(t, ())
Body::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
4 1 ─ %1  = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64}                       │╻  getindex
  │         getfield(%%t, 1)                                                      │╻  tail
  │   %3  = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64}                     ││
  │   %4  = Core.tuple(%3)::Tuple{Tuple{Tuple{Int64,Int64},Int64}}                │
  │   %5  = Base.getfield(%1, 1, true)::Int64                                     ││╻  getindex
  │   %6  = Core.tuple(%5)::Tuple{Int64}                                          ││
  │         getfield(%1, 1)                                                       ││╻  tail
  │   %8  = getfield(%1, 2)::Int64                                                │││
  │   %9  = Core.tuple(%8)::Tuple{Int64}                                          │││
  │   %10 = invoke Main.cat_nested(%9::Tuple{Int64}, %4::Tuple{Tuple{Tuple{Int64,Int64},Int64}})::Tuple{Int64,Int64,Int64,Vararg{Any,N} where N}
  │   %11 = Core._apply(Core.tuple, %6, %10)::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
  └──       return %11

@gasagna
Contributor

gasagna commented Jul 15, 2018

Consider this MWE, a similar case to the OP:

using BenchmarkTools

struct AugmentedState{X, Q}
    x::X
    q::Q
end

_state(x::AugmentedState) = x.x
_quad(x::AugmentedState) = x.q
_state(x) = x
_quad(x) = x


if VERSION > v"0.6.5"
    using Printf
    @inline Broadcast.broadcastable(x::AugmentedState) = x
    Base.ndims(::Type{<:AugmentedState}) = 0

    @inline function Broadcast.materialize!(dest::AugmentedState, bc::Broadcast.Broadcasted)
        bcf = Broadcast.flatten(bc)
        Broadcast.broadcast!(bcf.f, _state(dest), map(_state, bcf.args)...)
        Broadcast.broadcast!(bcf.f, _quad(dest),  map(_quad,  bcf.args)...)
        return dest
    end
else
    @inline function Base.Broadcast.broadcast!(f, dest::AugmentedState, args...)
        broadcast!(f, _state(dest), map(_state, args)...)
        broadcast!(f,  _quad(dest), map(_quad,  args)...)
        return dest
    end
end

a = AugmentedState(rand(100), rand(100))
b = AugmentedState(rand(100), rand(100))
c = AugmentedState(rand(100), rand(100))
d = AugmentedState(rand(100), rand(100))

bar(a, b, c, d) = (a .= a .+ 2.0.*b .+ 5.0.*c .- 7.0.*d; a)

t1 = @belapsed $bar($a, $b, $c, $d)
t2 = @belapsed (bar($a.x, $b.x, $c.x, $d.x); bar($a.q, $b.q, $c.q, $d.q))

@printf "AugumentedState = %7d ns\n" t1*10^9
@printf "Arrays          = %7d ns\n" t2*10^9
@printf "Ratio           = %7d x  \n"  t1/t2

Using v0.6.4 I get:

AugumenteState =     615 ns
Arrays         =     616 ns
Ratio          =       1 x  

while on latest master I get:

AugumentedState =   13962 ns
Arrays          =     256 ns
Ratio           =      55 x  

Checking out #28111 does not seem to make a large difference for this case, though.

@ChrisRackauckas
Member

ChrisRackauckas commented Jul 20, 2018

Closed by accident? @gasagna showed #28111 likely isn't the solution (at least for his related problem).

Edit: I see, that's considered a different issue #28126

@gasagna
Contributor

gasagna commented Jul 21, 2018

Yes, using latest master, the performance of the code in my previous comment does not improve and has high allocations.

@marius311
Contributor Author

Hi @mbauman and others, so #28111 definitely fixed the code in my original post, but the following still does not work:

using Test, SparseArrays
foo(a) = @. a*a*a*a*a*a*a
bar(a) = @. a*(a*a)*a*a*a*a
@inferred foo(spzeros(1)) # inferred
@inferred bar(spzeros(1)) # not inferred

(since it was pointed out above that SparseArrays use flatten, I was able to reduce this further to not bother with the NumWrapper thing).

Oddly enough (and hopefully a hint), putting the parentheses around any other pair of a's is fine.

I think this is different from #28126 / @gasagna's example above, because in that case there is no inference failure, while here there is.

@marius311
Contributor Author

marius311 commented Aug 5, 2018

Messing around some more, I think I've reduced it further to what might be the same "order-of-compilation difference in inference" mentioned in #27988 (comment) above, although that was before #28111 and this is after it (specifically on 0.7rc2).

Here's a MWE. This will infer correctly:

using Test
using Base.Broadcast: cat_nested, Broadcasted
@inferred cat_nested(1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)
@inferred cat_nested(Broadcasted(*, (1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)))

and in a fresh session, this will fail (this is the same thing as above with the 3rd line commented out):

using Test
using Base.Broadcast: cat_nested, Broadcasted
#@inferred cat_nested(1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)
@inferred cat_nested(Broadcasted(*, (1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)))

with

ERROR: return type NTuple{7,Int64} does not match inferred return type Tuple{Int64,Int64,Int64,Int64,Int64,Vararg{Int64,N} where N}

That failure in the second case is exactly what's causing the failure for a*(a*a)*a*a*a*a above.

@marius311
Contributor Author

Curious if anyone has made progress on this since the dust has settled after JuliaCon?

After more reading / digging, I venture to guess that what is happening here is exactly what is described in the "Independence of the cached result" section of https://juliacomputing.com/blog/2017/05/15/inference-converage2.html, which "remains an unsolved problem". This is probably clear to the people who might be able to fix this anyway (if I'm right that that's what it is), but I figured it was worth mentioning.

marius311 changed the title from "custom broadcasting inference failure on 0.7" to "inference failure for broadcast implementations that use flatten (eg SparseArrays)" Sep 12, 2018
@jlchan

jlchan commented Feb 3, 2021

Possibly related MWE: no type instability, but some extra allocations show up, depending on the order of operations, for types that use Broadcast.flatten.

foo1(x,y) = @. x*(x + y) + y*y
foo2(x,y) = @. (x + y)*x + y*y
foo3(x,y) = @. x*x + x*y + y*y

Using StaticArrays, foo1 allocates, while foo2 and foo3 do not.

using StaticArrays, BenchmarkTools
x = @SVector [1.0]
y = @SVector [2.0]
@btime foo1($x,$y); # 478.870 ns (25 allocations: 480 bytes)
@btime foo2($x,$y); # 2.949 ns (0 allocations: 0 bytes)

With SparseArrays, only foo3 avoids the extra allocations:

using SparseArrays
x = sparsevec([1.0])
y = sparsevec([2.0])

julia> @btime foo1($x,$y);
  1.111 μs (47 allocations: 1.09 KiB)
julia> @btime foo2($x,$y);
  1.124 μs (29 allocations: 1.20 KiB)
julia> @btime foo3($x,$y);
  124.881 ns (2 allocations: 192 bytes)
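As a sanity check that the three variants differ only in evaluation order and not in result, here is a sketch with the same sparse inputs. It uses Base's @allocated instead of BenchmarkTools, and the allocation numbers depend on the Julia version, so only the values are compared exactly:

```julia
using SparseArrays

foo1(x, y) = @. x*(x + y) + y*y
foo2(x, y) = @. (x + y)*x + y*y
foo3(x, y) = @. x*x + x*y + y*y

x = sparsevec([1.0])
y = sparsevec([2.0])

# All three orderings compute 1*(1 + 2) + 2*2 == 1 + 2 + 4 == 7
r1, r2, r3 = foo1(x, y), foo2(x, y), foo3(x, y)

# The calls above already compiled each method, so this measures run-time allocations only
allocs = (@allocated(foo1(x, y)), @allocated(foo2(x, y)), @allocated(foo3(x, y)))
```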

@ChrisRackauckas
Member

Does that still occur if you use DiffEqBase.@..?

@jlchan

jlchan commented Apr 9, 2021

Yeah, for StaticArrays arguments:

julia> foo1(x,y) = DiffEqBase.@.. x*(x + y) + y*y
foo1 (generic function with 1 method)

julia> foo2(x,y) = DiffEqBase.@.. (x + y)*x + y*y
foo2 (generic function with 1 method)

julia> @btime foo1($x,$y); # 478.870 ns (25 allocations: 480 bytes)
  477.878 ns (25 allocations: 480 bytes)

julia> @btime foo2($x,$y); # 2.949 ns (0 allocations: 0 bytes)
  1.377 ns (0 allocations: 0 bytes)

@chriselrod
Contributor

Inference seems to be failing; I saw a lot of ::Vararg{Float64,N} where N's.
Shot in the dark: is inference bailing on the recursion in the head/tail makeargs?

github-merge-queue bot pushed a commit that referenced this issue Jul 15, 2023
…d and inlined) (#43322)

A follow-up attempt to fix #27988. (close #47493 close #50554)
Examples:
```julia
julia> using LazyArrays
julia> bc = @~ @. 1*(1 + 1) + 1*1;
julia> bc2 = @~ 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3;
```
On master:
<details><summary> click for details </summary>
<p>

```julia
julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1)
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 4)::Int64
│   %5  = Core.getfield(args, 5)::Int64
│   %6  = invoke Base.Broadcast.var"#13#14"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(+)}}(Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(+)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), +))(%1::Int64, %2::Int64, %3::Vararg{Int64}, %4, %5)::Tuple{Int64, Int64, Vararg{Int64}}
│   %7  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %6)::Tuple{Int64, Int64}
│   %8  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %6)::Tuple{Vararg{Int64}}
│   %9  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#9#11", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(*)}(Base.Broadcast.var"#9#11"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), *), %8)::Tuple{Int64}
│   %10 = Core.getfield(%7, 1)::Int64
│   %11 = Core.getfield(%7, 2)::Int64
│   %12 = Base.mul_int(%10, %11)::Int64
│   %13 = Core.getfield(%9, 1)::Int64
│   %14 = Base.add_int(%12, %13)::Int64
└──       return %14
) => Int64

julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3))
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 5)::Int64
│   %5  = Core.getfield(args, 7)::Int64
│   %6  = Core.getfield(args, 8)::Int64
│   %7  = Core.getfield(args, 10)::Int64
│   %8  = invoke Base.Broadcast.var"#13#14"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}}(Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"()))), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()))), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"()))), Base.literal_pow))(%3::Int64, ^::Function, %4::Vararg{Any}, $(QuoteNode(Val{2}())), %5, %6, ^, %7, $(QuoteNode(Val{3}())))::Tuple{Int64, Any, Vararg{Any}}
│   %9  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %8)::Tuple{Int64, Any}
│   %10 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %8)::Tuple
│   %11 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#15#17"(), %10)::Tuple
│   %12 = Core.getfield(%9, 1)::Int64
│   %13 = Core.getfield(%9, 2)::Any
│   %14 = (*)(%12, %13)::Any
│   %15 = Core.tuple(%14)::Tuple{Any}
│   %16 = Core._apply_iterate(Base.iterate, Core.tuple, %15, %11)::Tuple{Any, Vararg{Any}}
│   %17 = Base.mul_int(%1, %2)::Int64
│   %18 = Core.tuple(%17)::Tuple{Int64}
│   %19 = Core._apply_iterate(Base.iterate, Core.tuple, %18, %16)::Tuple{Int64, Any, Vararg{Any}}
│   %20 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %19)::Tuple{Int64, Any}
│   %21 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %19)::Tuple
│   %22 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(*)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), *), %21)::Tuple{Any, Vararg{Any}}
│   %23 = Core.getfield(%20, 1)::Int64
│   %24 = Core.getfield(%20, 2)::Any
│   %25 = (-)(%23, %24)::Any
│   %26 = Core.tuple(%25)::Tuple{Any}
│   %27 = Core._apply_iterate(Base.iterate, Core.tuple, %26, %22)::Tuple{Any, Any, Vararg{Any}}
│   %28 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %27)::Tuple{Any, Any}
│   %29 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %27)::Tuple
│   %30 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#9#11", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"#9#11"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"()))), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()))), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"()))), Base.literal_pow), %29)::Tuple{Any}
│   %31 = Core.getfield(%28, 1)::Any
│   %32 = Core.getfield(%28, 2)::Any
│   %33 = (+)(%31, %32)::Any
│   %34 = Core.getfield(%30, 1)::Any
│   %35 = (+)(%33, %34)::Any
└──       return %35
) => Any
```
</p>

</details>

On this PR
```julia
julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1)
CodeInfo(
1 ─ %1 = Core.getfield(args, 1)::Int64
│   %2 = Core.getfield(args, 2)::Int64
│   %3 = Core.getfield(args, 3)::Int64
│   %4 = Core.getfield(args, 4)::Int64
│   %5 = Core.getfield(args, 5)::Int64
│   %6 = Base.add_int(%2, %3)::Int64
│   %7 = Base.mul_int(%1, %6)::Int64
│   %8 = Base.mul_int(%4, %5)::Int64
│   %9 = Base.add_int(%7, %8)::Int64
└──      return %9
) => Int64

julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3))
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 5)::Int64
│   %5  = Core.getfield(args, 7)::Int64
│   %6  = Core.getfield(args, 8)::Int64
│   %7  = Core.getfield(args, 10)::Int64
│   %8  = Base.mul_int(%1, %2)::Int64
│   %9  = Base.mul_int(%4, %4)::Int64
│   %10 = Base.mul_int(%3, %9)::Int64
│   %11 = Base.sub_int(%8, %10)::Int64
│   %12 = Base.mul_int(%5, %6)::Int64
│   %13 = Base.add_int(%11, %12)::Int64
│   %14 = Base.mul_int(%7, %7)::Int64
│   %15 = Base.mul_int(%14, %7)::Int64
│   %16 = Base.add_int(%13, %15)::Int64
└──       return %16
) => Int64
```
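To reproduce the first example above without LazyArrays, the lazy Broadcasted tree can be built directly with Base.Broadcast.broadcasted (a sketch; flatten is internal, so details may vary across Julia versions). Base.return_types reports the inferred return type without reading through the whole @code_typed dump:

```julia
using Base.Broadcast: broadcasted, flatten

# LazyArrays-free equivalent of `@~ @. 1*(1 + 1) + 1*1` from the PR description
bc = broadcasted(+, broadcasted(*, 1, broadcasted(+, 1, 1)), broadcasted(*, 1, 1))
fl = flatten(bc)

length(fl.args)    # 5 leaf arguments, all equal to 1
fl.f(fl.args...)   # 1*(1 + 1) + 1*1 == 3

# Inferred return type of the flattened function for five Int arguments;
# the @code_typed output above reports Int64 for this case
rts = Base.return_types(fl.f, NTuple{5,Int})
```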