From 4229760fc913624ff291b3187107b2d4f5c3d9bc Mon Sep 17 00:00:00 2001 From: lxvm Date: Tue, 2 Jan 2024 16:34:36 -0800 Subject: [PATCH 1/2] initial commit --- src/lib/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index eaa49ada2..4559237e7 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -22,7 +22,7 @@ accum(x, y) = accum(x, y, zs...) = accum(accum(x, y), zs...) accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...) -accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...) +accum(x::AbstractArray, ys::AbstractArray...) = Base.broadcast_preserving_zero_d(accum, x, ys...) @generated function accum(x::NamedTuple, y::NamedTuple) # assumes that y has no keys apart from those also in x From 65a32f516e21fa5dadf054a1ff5445e9cd436669 Mon Sep 17 00:00:00 2001 From: lxvm Date: Tue, 2 Jan 2024 16:40:11 -0800 Subject: [PATCH 2/2] add test --- test/lib/lib.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/lib/lib.jl b/test/lib/lib.jl index 0886b9969..11e64cba9 100644 --- a/test/lib/lib.jl +++ b/test/lib/lib.jl @@ -4,5 +4,6 @@ t2 = (a=1, b=2) @test Zygote.accum(t1, t2) == (a = 2, b = 4, c = 3) @test_throws ArgumentError Zygote.accum(t2, t1) + @test Zygote.accum(fill(0.0), fill(0.0)) == fill(0.0) end end