Skip to content

Commit

Permalink
Fix type stability of Broadcast.flatten
Browse files Browse the repository at this point in the history
Fixes #27988.
  • Loading branch information
mbauman authored and JeffBezanson committed Jul 20, 2018
1 parent 5d876f4 commit 67b06af
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
13 changes: 4 additions & 9 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ some cases.
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
# concatenate the nested arguments into {a, b, c, d}
args = cat_nested(x->x.args, bc)
args = cat_nested(bc)
# build a function `makeargs` that takes a "flat" argument list and
# and creates the appropriate input arguments for `f`, e.g.,
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
Expand All @@ -318,14 +318,9 @@ _isflat(args::NestedTuple) = false
_isflat(args::Tuple) = _isflat(tail(args))
_isflat(args::Tuple{}) = true

cat_nested(fieldextractor, bc::Broadcasted) = cat_nested(fieldextractor, fieldextractor(bc), ())

cat_nested(fieldextractor, t::Tuple, rest) =
(t[1], cat_nested(fieldextractor, tail(t), rest)...)
cat_nested(fieldextractor, t::Tuple{<:Broadcasted,Vararg{Any}}, rest) =
cat_nested(fieldextractor, cat_nested(fieldextractor, fieldextractor(t[1]), tail(t)), rest)
cat_nested(fieldextractor, t::Tuple{}, tail) = cat_nested(fieldextractor, tail, ())
cat_nested(fieldextractor, t::Tuple{}, tail::Tuple{}) = ()
cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
cat_nested() = ()

make_makeargs(bc::Broadcasted) = make_makeargs(()->(), bc.args)
@inline function make_makeargs(makeargs, t::Tuple)
Expand Down
11 changes: 11 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,17 @@ let X = zeros(2, 3)
@test X == [1 1 1; 2 2 2]
end

# issue #27988: inference of Broadcast.flatten
using .Broadcast: Broadcasted
let
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
end

# Issue #26127: multiple splats in a fused dot-expression
let f(args...) = *(args...)
x, y, z = (1,2), 3, (4, 5)
Expand Down

0 comments on commit 67b06af

Please sign in to comment.