Skip to content

Commit

Permalink
rebase and make dims a tuple again
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Jan 16, 2024
1 parent b5bba72 commit 455d24c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,15 @@ _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)
function prodfunc(xs, dy)
@assert length(first(dy)) == length(xs)
ndim = map(Zygote._ndims, xs)
cdim = cumsum((0, ndim[begin:end-1]...))
cdim = cumsum((1, ndim[begin:end-1]...))
getters = ntuple(n -> StaticGetter{n}(), Val(length(xs)))
dims = Vector{Int}(undef, length(xs))
map(first(dy), xs, cdim, ndim, getters) do dyn, x, cd, nd, getter
map(first(dy), xs, cdim, getters) do dyn, x, cd, getter
dyn === nothing && return nothing
append!(empty!(dims), 1:cd, cd+nd+1:ndims(dy))
nd = _ndims(x)
dims = nd == 0 ? (:) : ntuple(i -> i<cd ? i : i+nd, Val(ndims(dy)-nd))
init = map(zero, dyn) # allows for tuples, which accum can add:
red = mapreduce(getter, accum, dy; dims=_ndims(x) == 0 ? (:) : dims, init=init)
return _project(x, _ndims(x) == 0 ? red : reshape(red, axes(x)))
red = mapreduce(getter, accum, dy; dims, init)
return _project(x, nd == 0 ? red : reshape(red, axes(x)))
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_

for p in (1.0, fill(1.0), [1.0])
@test gradient(p -> sum([x*q for q in p, x in 1:3]), p) == (6p,)
# @test gradient(p -> sum(x*q for (q, p) in Iterators.product(p, 1:3)), p) == (6.0,)
@test gradient(p -> sum(x*q for (q, x) in Iterators.product(p, 1:3)), p) == (6p,)
end

@test gradient(x -> sum(broadcast(prod, Iterators.product(x,x))), ones(4)) == (2*4ones(4),)
Expand Down

0 comments on commit 455d24c

Please sign in to comment.