inference failure for broadcast implementations that use flatten (eg SparseArrays) #27988
Comments
Other than to say that I suspect you're working harder than you need to, this isn't quite enough to go on. What are the rules for indexing a So, you probably need some indexing rules, maybe like those in number.jl. And you should make sure that arithmetic works on its own: julia> a, b = NumWrapper(1), NumWrapper(2)
(NumWrapper{Int64}(1), NumWrapper{Int64}(2))
julia> a*a + (b*a)*b
ERROR: MethodError: no method matching *(::NumWrapper{Int64}, ::NumWrapper{Int64})
Closest candidates are:
*(::Any, ::Any, ::Any, ::Any...) at operators.jl:502
Stacktrace:
[1] top-level scope at none:0 The absence of these rules may be why you have non-inferrability. If they basically act like scalars, then it's just possible that you might need only one definition to support broadcasting: Base.BroadcastStyle(::Type{<:NumWrapper}) = Base.Broadcast.DefaultArrayStyle{0}() That declares that they act like 0-dimensional arrays, and I think the normal rules will work to preserve the "eltype." If you do have to give them a specific style, then you're going to have to define binary rules to specify how they combine with other containers. |
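A minimal sketch of the "make arithmetic work on its own first" advice, assuming a simplified single-field `NumWrapper` (the real type in the thread has more fields). Instead of the `BroadcastStyle` one-liner quoted above, this uses `Base.broadcastable(x) = Ref(x)`, which is the route the Julia manual describes for treating a custom type as a scalar in broadcasting; whether it fits the original use case is an assumption.

```julia
# Simplified stand-in for the NumWrapper discussed above (single field only).
struct NumWrapper{T<:Number}
    data::T
end

# Plain arithmetic must exist before broadcasting over it can work.
Base.:+(a::NumWrapper, b::NumWrapper) = NumWrapper(a.data + b.data)
Base.:*(a::NumWrapper, b::NumWrapper) = NumWrapper(a.data * b.data)

# Assumption: the wrapper should behave like a scalar (0-dimensional) in broadcasts.
Base.broadcastable(x::NumWrapper) = Ref(x)

a, b = NumWrapper(1), NumWrapper(2)
plain  = a*a + (b*a)*b           # ordinary arithmetic
dotted = a .* a .+ (b .* a) .* b # the same expression, broadcast
scaled = [a, b] .* b             # wrapper combined with a Vector of wrappers
```

With these definitions the parenthesized expression from the error message above evaluates without a `MethodError`, both with and without dots.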
Thanks, yea I think you're right, sorry, the above example has too much stripped out for you to be able to really give meaningful advice. To say a bit more, the real purpose is to have a structure (what I called struct Foo <: Field
a::Matrix
b::Matrix
end
struct Bar <: Field
c::Matrix
end
then you would have basically the following equivalency,
x::Foo .+ y::Foo .+ z::Foo = Foo(x.a .+ y.a .+ z.a, x.b .+ y.b .+ z.b)
Note that this isn't quite what e.g. tuple broadcasting does, because that doesn't "forward" the broadcast to the individual elements, so it can be inefficient if the elements are things like large matrices as is my case. Similarly, you would have,
x::Foo .+ w::Bar = Foo(x.a .+ w.c, x.b .+ w.c)
In the real case I do indeed have a set of binary broadcast rules, for example to decide that the above result is a I also have a |
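The equivalency described above can be sketched without touching the broadcast machinery at all. This is only an illustration of the intended semantics: `Foo` and `Bar` here are simplified stand-ins without the `Field` supertype, and `operand` and `fieldwise` are hypothetical helper names, not part of the code in this thread.

```julia
# Simplified stand-ins for the Foo/Bar types described above.
struct Foo{M<:AbstractMatrix}
    a::M
    b::M
end
struct Bar{M<:AbstractMatrix}
    c::M
end

# Hypothetical helper: pick the operand that feeds output field `name`.
operand(x::Foo, name::Symbol) = getfield(x, name)
operand(w::Bar, ::Symbol) = w.c      # a Bar contributes `c` to every field
operand(s::Number, ::Symbol) = s     # scalars pass through unchanged

# Forward one broadcast of `f` to each field and collect the results in a Foo.
fieldwise(f, args...) = Foo(f.(map(x -> operand(x, :a), args)...),
                            f.(map(x -> operand(x, :b), args)...))

x = Foo([1 2; 3 4], [10 20; 30 40])
y = Foo([1 1; 1 1], [2 2; 2 2])
w = Bar([100 100; 100 100])

foo_plus_foo = fieldwise(+, x, y)   # Foo(x.a .+ y.a, x.b .+ y.b)
foo_plus_bar = fieldwise(+, x, w)   # Foo(x.a .+ w.c, x.b .+ w.c)
```

Each field is computed by a single broadcast call over the underlying matrices, which is the "forwarding" that tuple broadcasting does not do.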
This helps. There's still some question in my mind about whether you should be expressing those operations as broadcasting If they do support indexing, then you might only need something along the lines of # These act like two-dimensional arrays, and should take precedence over e.g., Array
struct FooStyle <: Broadcast.AbstractArrayStyle{2} end
struct BarStyle <: Broadcast.AbstractArrayStyle{2} end
Base.BroadcastStyle(::Type{<:Foo}) = FooStyle()
Base.BroadcastStyle(::Type{<:Bar}) = BarStyle()
# FooStyle "beats" BarStyle: Foo.+Foo->Foo, Foo.+Bar->Foo, Bar.+Bar->Bar
Base.BroadcastStyle(::FooStyle, ::BarStyle) = FooStyle()
# Teach Julia how to allocate the output container
Base.similar(bc::Broadcast.Broadcasted{FooStyle}, ::Type{ElType}) where {ElType} = Foo{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{BarStyle}, ::Type{ElType}) where {ElType} = Bar{ElType}(undef, size(bc))
Julia's internals should handle the rest automatically for you 😄. |
I would agree with Tim on the broadcasting design, although I must confess that this is a rather clever abuse that allows forwarding of arbitrary functions. You may run into trouble if your type hits another that overrides broadcast styles, though. As far as the inference failure, it's stemming from julia> bc = Broadcast.Broadcasted(+, (Broadcast.Broadcasted(*, (1, 2)), Broadcast.Broadcasted(*, (Broadcast.Broadcasted(*, (3, 4)), 5))));
julia> Broadcast.cat_nested(x->x.args, bc)
(1, 2, 3, 4, 5)
julia> @code_warntype Broadcast.cat_nested(x->x.args, bc)
Body::Tuple{Int64,Int64,Vararg{Any,N} where N}
… |
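For readers on a current Julia release, here is a small sketch of what flattening does to a nested `Broadcasted` tree. It assumes the non-exported `Broadcast.broadcasted` and `Broadcast.flatten` keep their present behavior; the `cat_nested(x->x.args, bc)` call quoted above is the 0.7-era internal signature and may not exist in that form today.

```julia
# Build the lazy tree for  1*2 + (3*4)*5  without evaluating it.
inner = Broadcast.broadcasted(*, Broadcast.broadcasted(*, 3, 4), 5)
bc = Broadcast.broadcasted(+, Broadcast.broadcasted(*, 1, 2), inner)

# flatten returns a single-level Broadcasted: one fused function
# plus a flat tuple holding the leaves of the tree.
fbc = Broadcast.flatten(bc)

flat_args = fbc.args                 # leaves, in left-to-right order
fused_value = fbc.f(flat_args...)    # apply the fused function to the leaves
```

The inference question in this issue is whether the compiler can see the concrete tuple type of `flat_args` for arbitrarily nested trees.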
Thanks for the replies. So my |
I guess the question I have then is: is this a performance bug or not? It sounds like this may be a case where broadcasting is inappropriate but this is an optimization we want to work anyway? |
This is a real performance issue, and it'll have an effect on any broadcasting implementation that uses using Base: tail
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(cat_nested(t[1], tail(t)), rest)
cat_nested(t::Tuple{}, tail) = tail
t = ((1, 2), ((3, 4), 5))
@code_warntype cat_nested(t, ()) This worked in Julia 0.6, but I imagine that it broke as part of @vtjnash's work in limiting non-terminating recursive inference. Is there a workaround here? Or could we support this? The goal is simply to flatten arbitrarily nested tuples — in this case, just return |
Ah, here's the workaround: don't be doubly-recursive within a single method: using Base: tail
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
cat_nested(t::Tuple{}, tail::Tuple{}) = ()
t = ((1, 2), ((3, 4), 5))
@code_warntype cat_nested(t, ()) |
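To check that the restructured methods still compute the same flattening, here is a self-contained copy under the hypothetical name `flatten_tuple` (renamed only so it does not clash with the definitions above). It checks values only; what `@code_warntype` reports depends on the Julia version, so no inference result is claimed here.

```julia
using Base: tail

# Same four methods as the workaround above, under a hypothetical name.
# `rest` carries the not-yet-flattened remainder, so each method makes
# a single recursive call instead of nesting one call inside another.
flatten_tuple(t::Tuple, rest) = (t[1], flatten_tuple(tail(t), rest)...)
flatten_tuple(t::Tuple{Tuple,Vararg{Any}}, rest) = flatten_tuple(t[1], (tail(t)..., rest...))
flatten_tuple(t::Tuple{}, rest) = flatten_tuple(rest, ())
flatten_tuple(t::Tuple{}, rest::Tuple{}) = ()

# Convenience entry point that starts with an empty remainder.
flatten_tuple(t::Tuple) = flatten_tuple(t, ())

flat  = flatten_tuple(((1, 2), ((3, 4), 5)))
mixed = flatten_tuple((1.0, ("a", (:b,)), ()))   # empty tuples vanish
```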
Ok, this is bizarre. I'm getting a state-dependent order-of-compilation difference in inference: $ ./julia
_
_ _ _(_)_ | A fresh approach to technical computing
(_) | (_) (_) | Documentation: https://docs.julialang.org
_ _ _| |_ __ _ | Type "?" for help, "]?" for Pkg help.
| | | | | | |/ _` | |
| | |_| | | | (_| | | Version 0.7.0-beta.283 (2018-07-12 22:44 UTC)
_/ |\__'_|_|_|\__'_| | Commit 98061abb5a* (0 days old master)
|__/ | x86_64-linux-gnu
julia> using Base: tail
cat_nested(t::Tuple) = cat_nested(t, ())
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
cat_nested(t::Tuple{}, tail::Tuple{}) = ()
t = ((1, 2), ((3, 4), 5))
((1, 2), ((3, 4), 5))
julia> @code_warntype cat_nested(t, ())
Body::NTuple{5,Int64}
4 1 ─ %1 = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64} │╻ getindex
│ getfield(%%t, 1) │╻ tail
│ %3 = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64} ││
│ %4 = Base.getfield(%1, 1, true)::Int64 ││╻ getindex
│ getfield(%1, 1) ││╻ tail
│ %6 = getfield(%1, 2)::Int64 │││
│ %7 = Base.getfield(%3, 1, true)::Tuple{Int64,Int64} │││╻╷╷╷ cat_nested
│ getfield(%3, 1) ││││╻ cat_nested
│ %9 = getfield(%3, 2)::Int64 │││││┃│ cat_nested
│ %10 = Base.getfield(%7, 1, true)::Int64 ││││││╻ cat_nested
│ getfield(%7, 1) │││││││╻ tail
│ %12 = getfield(%7, 2)::Int64 ││││││││
│ %13 = Core.tuple(%4, %6, %10, %12, %9)::NTuple{5,Int64} ││
└── return %13 │
julia> @code_warntype cat_nested(t)
Body::NTuple{5,Int64}
2 1 ─ %1 = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64} │╻╷ cat_nested
│ getfield(%%t, 1) ││╻ tail
│ %3 = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64} │││
│ %4 = Base.getfield(%1, 1, true)::Int64 │││╻ getindex
│ getfield(%1, 1) │││╻ tail
│ %6 = getfield(%1, 2)::Int64 ││││
│ %7 = Base.getfield(%3, 1, true)::Tuple{Int64,Int64} ││││╻╷╷╷ cat_nested
│ getfield(%3, 1) │││││╻ cat_nested
│ %9 = getfield(%3, 2)::Int64 ││││││┃│ cat_nested
│ %10 = Base.getfield(%7, 1, true)::Int64 │││││││╻ cat_nested
│ getfield(%7, 1) ││││││││╻ tail
│ %12 = getfield(%7, 2)::Int64 │││││││││
│ %13 = Core.tuple(%4, %6, %10, %12, %9)::NTuple{5,Int64} │││
└── return %13 │ But if I restart my session and call those two methods in a different order: $ ./julia -q
julia> using Base: tail
cat_nested(t::Tuple) = cat_nested(t, ())
cat_nested(t::Tuple, rest) = (t[1], cat_nested(tail(t), rest)...)
cat_nested(t::Tuple{Tuple,Vararg{Any}}, rest) = cat_nested(t[1], (tail(t)..., rest...))
cat_nested(t::Tuple{}, tail) = cat_nested(tail, ())
cat_nested(t::Tuple{}, tail::Tuple{}) = ()
t = ((1, 2), ((3, 4), 5))
((1, 2), ((3, 4), 5))
julia> @code_warntype cat_nested(t)
Body::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
2 1 ─ %1 = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64} │╻╷ cat_nested
│ getfield(%%t, 1) ││╻ tail
│ %3 = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64} │││
│ %4 = Core.tuple(%3)::Tuple{Tuple{Tuple{Int64,Int64},Int64}} ││
│ %5 = Base.getfield(%1, 1, true)::Int64 │││╻ getindex
│ %6 = Core.tuple(%5)::Tuple{Int64} │││
│ getfield(%1, 1) │││╻ tail
│ %8 = getfield(%1, 2)::Int64 ││││
│ %9 = Core.tuple(%8)::Tuple{Int64} ││││
│ %10 = invoke Main.cat_nested(%9::Tuple{Int64}, %4::Tuple{Tuple{Tuple{Int64,Int64},Int64}})::Tuple{Int64,Int64,Int64,Vararg{Any,N} where N}
│ %11 = Core._apply(Core.tuple, %6, %10)::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
└── return %11 │
julia> @code_warntype cat_nested(t, ())
Body::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
4 1 ─ %1 = Base.getfield(%%t, 1, true)::Tuple{Int64,Int64} │╻ getindex
│ getfield(%%t, 1) │╻ tail
│ %3 = getfield(%%t, 2)::Tuple{Tuple{Int64,Int64},Int64} ││
│ %4 = Core.tuple(%3)::Tuple{Tuple{Tuple{Int64,Int64},Int64}} │
│ %5 = Base.getfield(%1, 1, true)::Int64 ││╻ getindex
│ %6 = Core.tuple(%5)::Tuple{Int64} ││
│ getfield(%1, 1) ││╻ tail
│ %8 = getfield(%1, 2)::Int64 │││
│ %9 = Core.tuple(%8)::Tuple{Int64} │││
│ %10 = invoke Main.cat_nested(%9::Tuple{Int64}, %4::Tuple{Tuple{Tuple{Int64,Int64},Int64}})::Tuple{Int64,Int64,Int64,Vararg{Any,N} where N}
│ %11 = Core._apply(Core.tuple, %6, %10)::Tuple{Int64,Int64,Int64,Int64,Vararg{Any,N} where N}
└── return %11 │ |
Consider this MWE, a similar case to the OP: using BenchmarkTools
struct AugmentedState{X, Q}
x::X
q::Q
end
_state(x::AugmentedState) = x.x
_quad(x::AugmentedState) = x.q
_state(x) = x
_quad(x) = x
if VERSION > v"0.6.5"
using Printf
@inline Broadcast.broadcastable(x::AugmentedState) = x
Base.ndims(::Type{<:AugmentedState}) = 0
@inline function Broadcast.materialize!(dest::AugmentedState, bc::Broadcast.Broadcasted)
bcf = Broadcast.flatten(bc)
Broadcast.broadcast!(bcf.f, _state(dest), map(_state, bcf.args)...)
Broadcast.broadcast!(bcf.f, _quad(dest), map(_quad, bcf.args)...)
return dest
end
else
@inline function Base.Broadcast.broadcast!(f, dest::AugmentedState, args...)
broadcast!(f, _state(dest), map(_state, args)...)
broadcast!(f, _quad(dest), map(_quad, args)...)
return dest
end
end
a = AugmentedState(rand(100), rand(100))
b = AugmentedState(rand(100), rand(100))
c = AugmentedState(rand(100), rand(100))
d = AugmentedState(rand(100), rand(100))
bar(a, b, c, d) = (a .= a .+ 2.0.*b .+ 5.0.*c .- 7.0.*d; a)
t1 = @belapsed $bar($a, $b, $c, $d)
t2 = @belapsed (bar($a.x, $b.x, $c.x, $d.x); bar($a.q, $b.q, $c.q, $d.q))
@printf "AugumentedState = %7d ns\n" t1*10^9
@printf "Arrays = %7d ns\n" t2*10^9
@printf "Ratio = %7d x \n" t1/t2 Using AugumenteState = 615 ns
Arrays = 616 ns
Ratio = 1 x while on latest master I get: AugumentedState = 13962 ns
Arrays = 256 ns
Ratio = 55 x Checking out #28111 does not seem to make a large difference for this case, though. |
Yes, using latest master, the performance of the code in my previous comment does not improve and has high allocations. |
Hi @mbauman and others, so #28111 definitely fixed the code in my original post, but the following still does not work : using Test, SparseArrays
foo(a) = @. a*a*a*a*a*a*a
bar(a) = @. a*(a*a)*a*a*a*a
@inferred foo(spzeros(1)) # inferred
@inferred bar(spzeros(1)) # not inferred (since it was pointed out above that SparseArrays use Oddly enough (and hopefully a hint), the parenthesis around literally any other pair of 2 I think this is different than #28126 / @gasagna's example above because there there is no inference failure, while here there is. |
Messing around some more, I think I've reduced it further to what might be the same "order-of-compilation difference in inference" mentioned in #27988 (comment) above, although that was before #28111, here I'm after (specifically 0.7rc2) Here's a MWE. This will infer correctly: using Test
using Base.Broadcast: cat_nested, Broadcasted
@inferred cat_nested(1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)
@inferred cat_nested(Broadcasted(*, (1, Broadcasted(*, (2, 3)), 4, 5, 6, 7))) and in a fresh session, this will fail (this is the same thing as above with the 3rd line commented): using Test
using Base.Broadcast: cat_nested, Broadcasted
#@inferred cat_nested(1, Broadcasted(*, (2, 3)), 4, 5, 6, 7)
@inferred cat_nested(Broadcasted(*, (1, Broadcasted(*, (2, 3)), 4, 5, 6, 7))) with
That failure in the second case is exactly what's causing the failure for |
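As a reminder of what the `@inferred` checks in these MWEs test, here is a tiny sketch unrelated to broadcasting; `stable` and `unstable` are hypothetical names chosen for illustration. `Test.@inferred` returns the call's value when the inferred return type matches the actual one and throws otherwise.

```julia
using Test

stable(x) = x + 1                 # returns an Int for Int input
unstable(x) = x > 0 ? 1 : 1.0     # Int or Float64 depending on the value

ok = @inferred stable(1)          # passes and returns the call's value

# @inferred throws when inference cannot pin down the concrete return type.
threw = try
    @inferred unstable(1)
    false
catch
    true
end
```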
Curious if anyone has made progress on this since the dust has settled after JuliaCon? After more reading / digging, I venture to guess that what is happening here is exactly what is described in the "Independence of the cached result" section from https://juliacomputing.com/blog/2017/05/15/inference-converage2.html which "remains an unsolved problem". This is probably clear to the people that might be able to fix this anyway (if I'm right that that's what it is), but I figure worth mentioning. |
Possibly related MWE - no type instability, but some extra allocations show up depending on order of operations for types which use foo1(x,y) = @. x*(x + y) + y*y
foo2(x,y) = @. (x + y)*x + y*y
foo3(x,y) = @. x*x + x*y + y*y Using StaticArrays, using StaticArrays
x = @SVector [1.0]
y = @SVector [2.0]
@btime foo1($x,$y); # 478.870 ns (25 allocations: 480 bytes)
@btime foo2($x,$y); # 2.949 ns (0 allocations: 0 bytes) Using SparseArrays using SparseArrays
x = sparsevec([1.0])
y = sparsevec([2.0]) only julia> @btime foo1($x,$y);
1.111 μs (47 allocations: 1.09 KiB)
julia> @btime foo2($x,$y);
1.124 μs (29 allocations: 1.20 KiB)
julia> @btime foo3($x,$y);
124.881 ns (2 allocations: 192 bytes) |
Does that still occur if you use |
Yeah: for StaticArrays arguments julia> foo1(x,y) = DiffEqBase.@.. x*(x + y) + y*y
foo1 (generic function with 1 method)
julia> foo2(x,y) = DiffEqBase.@.. (x + y)*x + y*y
foo2 (generic function with 1 method)
julia> @btime foo1($x,$y); # 478.870 ns (25 allocations: 480 bytes)
477.878 ns (25 allocations: 480 bytes)
julia> @btime foo2($x,$y); # 2.949 ns (0 allocations: 0 bytes)
1.377 ns (0 allocations: 0 bytes) |
Inference seems to be failing. Saw a lot of |
…d and inlined) (#43322) A follow up attemp to fix #27988. (close #47493 close #50554) Examples: ```julia julia> using LazyArrays julia> bc = @~ @. 1*(1 + 1) + 1*1; julia> bc2 = @~ 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3; ``` On master: <details><summary> click for details </summary> <p> ```julia julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1) CodeInfo( 1 ─ %1 = Core.getfield(args, 1)::Int64 │ %2 = Core.getfield(args, 2)::Int64 │ %3 = Core.getfield(args, 3)::Int64 │ %4 = Core.getfield(args, 4)::Int64 │ %5 = Core.getfield(args, 5)::Int64 │ %6 = invoke Base.Broadcast.var"#13#14"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(+)}}(Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(+)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), +))(%1::Int64, %2::Int64, %3::Vararg{Int64}, %4, %5)::Tuple{Int64, Int64, Vararg{Int64}} │ %7 = Core._apply_iterate(Base.iterate, 
Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %6)::Tuple{Int64, Int64} │ %8 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %6)::Tuple{Vararg{Int64}} │ %9 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#9#11", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(*)}(Base.Broadcast.var"#9#11"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), *), %8)::Tuple{Int64} │ %10 = Core.getfield(%7, 1)::Int64 │ %11 = Core.getfield(%7, 2)::Int64 │ %12 = Base.mul_int(%10, %11)::Int64 │ %13 = Core.getfield(%9, 1)::Int64 │ %14 = Base.add_int(%12, %13)::Int64 └── return %14 ) => Int64 julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3)) CodeInfo( 1 ─ %1 = Core.getfield(args, 1)::Int64 │ %2 = Core.getfield(args, 2)::Int64 │ %3 = Core.getfield(args, 3)::Int64 │ %4 = Core.getfield(args, 5)::Int64 │ %5 = Core.getfield(args, 7)::Int64 │ %6 = Core.getfield(args, 8)::Int64 │ %7 = Core.getfield(args, 10)::Int64 │ %8 = invoke 
Base.Broadcast.var"#13#14"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}}(Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"()))), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()))), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"()))), Base.literal_pow))(%3::Int64, ^::Function, %4::Vararg{Any}, $(QuoteNode(Val{2}())), %5, %6, ^, %7, $(QuoteNode(Val{3}())))::Tuple{Int64, Any, Vararg{Any}} │ %9 = Core._apply_iterate(Base.iterate, 
Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %8)::Tuple{Int64, Any} │ %10 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %8)::Tuple │ %11 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#15#17"(), %10)::Tuple │ %12 = Core.getfield(%9, 1)::Int64 │ %13 = Core.getfield(%9, 2)::Any │ %14 = (*)(%12, %13)::Any │ %15 = Core.tuple(%14)::Tuple{Any} │ %16 = Core._apply_iterate(Base.iterate, Core.tuple, %15, %11)::Tuple{Any, Vararg{Any}} │ %17 = Base.mul_int(%1, %2)::Int64 │ %18 = Core.tuple(%17)::Tuple{Int64} │ %19 = Core._apply_iterate(Base.iterate, Core.tuple, %18, %16)::Tuple{Int64, Any, Vararg{Any}} │ %20 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %19)::Tuple{Int64, Any} │ %21 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %19)::Tuple │ %22 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(*)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), 
Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), *), %21)::Tuple{Any, Vararg{Any}} │ %23 = Core.getfield(%20, 1)::Int64 │ %24 = Core.getfield(%20, 2)::Any │ %25 = (-)(%23, %24)::Any │ %26 = Core.tuple(%25)::Tuple{Any} │ %27 = Core._apply_iterate(Base.iterate, Core.tuple, %26, %22)::Tuple{Any, Any, Vararg{Any}} │ %28 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %27)::Tuple{Any, Any} │ %29 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %27)::Tuple │ %30 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#9#11", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"#9#11"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"()))), 
Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()))), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"()))), Base.literal_pow), %29)::Tuple{Any} │ %31 = Core.getfield(%28, 1)::Any │ %32 = Core.getfield(%28, 2)::Any │ %33 = (+)(%31, %32)::Any │ %34 = Core.getfield(%30, 1)::Any │ %35 = (+)(%33, %34)::Any └── return %35 ) => Any ``` </p> </details> On this PR ```julia julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1) CodeInfo( 1 ─ %1 = Core.getfield(args, 1)::Int64 │ %2 = Core.getfield(args, 2)::Int64 │ %3 = Core.getfield(args, 3)::Int64 │ %4 = Core.getfield(args, 4)::Int64 │ %5 = Core.getfield(args, 5)::Int64 │ %6 = Base.add_int(%2, %3)::Int64 │ %7 = Base.mul_int(%1, %6)::Int64 │ %8 = Base.mul_int(%4, %5)::Int64 │ %9 = Base.add_int(%7, %8)::Int64 └── return %9 ) => Int64 julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3)) CodeInfo( 1 ─ %1 = Core.getfield(args, 1)::Int64 │ %2 = Core.getfield(args, 2)::Int64 │ %3 = Core.getfield(args, 3)::Int64 │ %4 = Core.getfield(args, 5)::Int64 │ %5 = Core.getfield(args, 7)::Int64 │ %6 = Core.getfield(args, 8)::Int64 │ %7 = Core.getfield(args, 10)::Int64 │ %8 = Base.mul_int(%1, %2)::Int64 │ %9 = Base.mul_int(%4, %4)::Int64 │ %10 = Base.mul_int(%3, %9)::Int64 │ %11 = Base.sub_int(%8, %10)::Int64 │ %12 = Base.mul_int(%5, %6)::Int64 │ %13 = Base.add_int(%11, %12)::Int64 │ %14 = Base.mul_int(%7, %7)::Int64 │ %15 = Base.mul_int(%14, %7)::Int64 │ %16 = Base.add_int(%13, %15)::Int64 └── return %16 ) => Int64 ```
I've been using the new broadcasting API to implement broadcasting for some custom types (which btw, the new API is great and massively simplified my code, thanks!). I did however run into the following surprising inference failure which I've reduced down to the following code:
As you can see, each `NumWrapper` just holds a number, and broadcasting over e.g. `a::NumWrapper .+ b::NumWrapper` becomes `NumWrapper(broadcast.(+,(a.data,),(b.data,))...) = NumWrapper((a.data .+ b.data,)...)`. Note I need the `.data` wrapped in a tuple in my real code because in the real case `NumWrapper` has multiple fields; this is also crucial to triggering the bug, although in the simple example above it probably seems unnecessary. In any case, you see that the seemingly unimportant addition of the parentheses around `(b*a)` spoils type stability. This is on commit 656d587.

Beyond a possible solution, I'm also curious if there's a workaround or a better way to code this up; I can't say I'm 100% sure I've used the new API as intended (in particular, I've used `flatten` and the docs make it seem like this shouldn't usually be necessary). In any case, hope this helps!

EDIT: Fixed a small mistake in the text describing what `a::NumWrapper .+ b::NumWrapper` became.