diff --git a/src/lib/lib.jl b/src/lib/lib.jl index eaa49ada2..4559237e7 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -22,7 +22,7 @@ accum(x, y) = accum(x, y, zs...) = accum(accum(x, y), zs...) accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...) -accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...) +accum(x::AbstractArray, ys::AbstractArray...) = Base.broadcast_preserving_zero_d(accum, x, ys...) @generated function accum(x::NamedTuple, y::NamedTuple) # assumes that y has no keys apart from those also in x diff --git a/test/lib/lib.jl b/test/lib/lib.jl index 0886b9969..11e64cba9 100644 --- a/test/lib/lib.jl +++ b/test/lib/lib.jl @@ -4,5 +4,6 @@ t2 = (a=1, b=2) @test Zygote.accum(t1, t2) == (a = 2, b = 4, c = 3) @test_throws ArgumentError Zygote.accum(t2, t1) + @test Zygote.accum(fill(0.0), fill(0.0)) == fill(0.0) end end