diff --git a/src/lib/array.jl b/src/lib/array.jl
index 745acae47..3f9d4a25b 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -271,15 +271,15 @@ _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)
 function prodfunc(xs, dy)
   @assert length(first(dy)) == length(xs)
   ndim = map(Zygote._ndims, xs)
-  cdim = cumsum((0, ndim[begin:end-1]...))
+  cdim = cumsum((1, ndim[begin:end-1]...))
   getters = ntuple(n -> StaticGetter{n}(), Val(length(xs)))
-  dims = Vector{Int}(undef, length(xs))
-  map(first(dy), xs, cdim, ndim, getters) do dyn, x, cd, nd, getter
+  map(first(dy), xs, cdim, getters) do dyn, x, cd, getter
     dyn === nothing && return nothing
-    append!(empty!(dims), 1:cd, cd+nd+1:ndims(dy))
+    nd = _ndims(x)
+    dims = nd == 0 ? (:) : ntuple(i -> i<cd ? i : i+nd, ndims(dy)-nd)

 sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_

 for p in (1.0, fill(1.0), [1.0])
   @test gradient(p -> sum([x*q for q in p, x in 1:3]), p) == (6p,)
-  # @test gradient(p -> sum(x*q for (q, p) in Iterators.product(p, 1:3)), p) == (6.0,)
+  @test gradient(p -> sum(x*q for (q, x) in Iterators.product(p, 1:3)), p) == (6p,)
 end
 @test gradient(x -> sum(broadcast(prod, Iterators.product(x,x))), ones(4)) == (2*4ones(4),)
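
For context, a minimal sketch of what the rewritten `dims` logic computes, under the same convention as the patch: `cdim` now starts the cumulative sum at 1, so each factor's `cd` is the first dimension it owns in `dy`, and every other dimension of `dy` is reduced over (all of them, via `(:)`, when the factor is zero-dimensional). The helper name `reduction_dims` below is introduced purely for illustration and is not part of the patch.

```julia
# Hypothetical helper mirroring the patched `dims` expression:
# keep dims cd:cd+nd-1 for the factor itself, reduce over the rest.
reduction_dims(cd, nd, ndims_dy) =
    nd == 0 ? (:) : ntuple(i -> i < cd ? i : i + nd, ndims_dy - nd)

xs = (rand(2), rand(3))          # Iterators.product of a length-2 and a length-3 vector
dy = ones(2, 3)                  # cotangent: one dimension per factor
ndim = map(ndims, xs)            # (1, 1)
cdim = cumsum((1, ndim[begin:end-1]...))   # (1, 2): first dim owned by each factor

# Factor 1 owns dim 1 of dy, so its gradient reduces over dim 2, and vice versa.
@assert reduction_dims(cdim[1], ndim[1], ndims(dy)) == (2,)
@assert reduction_dims(cdim[2], ndim[2], ndims(dy)) == (1,)

# A zero-dimensional factor (e.g. fill(1.0)) owns no dimensions of dy,
# so everything is reduced, matching the new `nd == 0` branch.
@assert reduction_dims(1, 0, ndims(dy)) === Colon()
```

Replacing the shared mutable `Vector{Int}` buffer with a per-factor tuple (or `Colon()`) avoids mutation inside the closure and covers zero-dimensional inputs such as `fill(1.0)`, which the updated test loop now exercises.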