Skip to content

Commit

Permalink
simplest prod gradient, again
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Abbott authored and haampie committed Feb 26, 2020
1 parent d6d8151 commit ab08984
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
10 changes: 3 additions & 7 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,9 @@ end
return sum(abs2, X; dims=dims), Δ::Union{Number, AbstractArray}->(nothing, ((2Δ) .* X))
end

@adjoint function prod(xs::AbstractArray{<:Number}; dims = :)
if dims === (:)
prod(xs), Δ -> (prod(xs) ./ xs .* Δ,)
else
prod(xs, dims = dims),
Δ -> (reshape(.*(circshift.([reshape(xs, length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ,)
end
@adjoint function prod(xs; dims = :)
p = prod(xs; dims = dims)
p, Δ -> (p ./ xs .* Δ,)
end

function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
Expand Down
6 changes: 3 additions & 3 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ Random.seed!(0)
@test gradtest(x -> sum(x[i] for i in 1:length(x)), randn(10))
@test_broken gradtest(x -> sum(i->x[i], 1:length(x)), randn(10)) # https://github.com/FluxML/Zygote.jl/issues/231

@test_broken gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4,5))
@test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4))

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
Expand Down Expand Up @@ -112,7 +112,7 @@ end
end

@testset "conv: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(repeat([10], spatial_rank)..., 3, 2)
x = rand(repeat([5], spatial_rank)..., 3, 2)
w = rand(repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
@test gradtest((x, w) -> conv(x, w, cdims), x, w)
Expand Down

0 comments on commit ab08984

Please sign in to comment.