Merge #512 #518
512: Gradient for power r=CarloLucibello a=mcabbott

Closes #511, closes #247, and #426.

518: Special adjoint for broadcasted literal pow r=CarloLucibello a=haampie

Currently, taking the gradient of anything that contains a broadcasted
literal pow adds `RefValue{typeof(^)}(^)`, and a similar entry for the
literal power itself, to the IdDict. This is probably because of the
special signature used in the broadcasting machinery:

```
Base.broadcasted(Base.literal_pow, Main.:^, vec, %2)
```

where `%2` is a `Val{N}` instance.
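
For reference, this lowering can be reproduced with `Meta.@lower` (a quick illustration, not part of the diff; `xs` is a placeholder name):

```
# Lowering is purely syntactic, so xs need not be defined.
Meta.@lower xs .^ 2
# The printed IR contains a call of the form
#   Base.broadcasted(Base.literal_pow, Main.:^, xs, %n)
# where %n is a Val{2} instance.
```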

By adding a special adjoint for broadcasted `literal_pow`, we not only
reduce the noise in the params' IdDict but also speed up taking the
gradient of basic loss functions like `sum(err.^2)`.
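
As a rough usage sketch of the kind of call this targets (illustrative only, assuming a Zygote build with this change; `err` is a made-up variable):

```
using Zygote

err = randn(100)
g, = gradient(e -> sum(e .^ 2), err)  # broadcasted literal pow, handled by the special adjoint
g ≈ 2 .* err                          # analytic gradient: d/de_i sum(e.^2) = 2*e_i
```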

Ref #513; also solves most of FluxML/Flux.jl#1018.

Co-authored-by: Michael Abbott <me@escbook>
Co-authored-by: Harmen Stoppels <[email protected]>
3 people authored Feb 27, 2020
3 parents 200ae53 + 78fc0d6 + 3ec5b4b commit 99fbd43
Showing 3 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/lib/broadcast.jl
@@ -76,6 +76,11 @@ Numeric{T<:Number} = Union{T,AbstractArray{<:T}}
  res, Δ -> (nothing, unbroadcast(x, Δ ./ y), unbroadcast(y, -Δ .* res ./ y))
end

@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
  y = Base.literal_pow.(^, x, exp)
  y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
end

@adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ)

@adjoint function broadcasted(::typeof(σ), x::Numeric)
7 changes: 7 additions & 0 deletions src/lib/number.jl
@@ -22,13 +22,20 @@ end

for (M, f, arity) in DiffRules.diffrules()
  arity == 2 || continue
  f == :^ && continue
  da, db = DiffRules.diffrule(M, f, :a, :b)
  @eval begin
    @adjoint $M.$f(a::Number, b::Number) = $M.$f(a, b),
      Δ -> (Δ * conj($da), Δ * conj($db))
  end
end

@adjoint Base.:^(x::Number, p::Number) = x^p,
  Δ -> (Δ * conj(p * x^(p-1)), Δ * conj(x^p * log(complex(x))))
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
  Base.literal_pow(^,x,Val(p)),
  Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)

@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ)
@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ)

17 changes: 17 additions & 0 deletions test/gradcheck.jl
@@ -42,6 +42,22 @@ Random.seed!(0)

@test gradient(//, 2, 3) === (1//3, -2//9)

@testset "power" begin
  @test gradient(x -> x^2, -2) == (-4,)
  @test gradient(x -> x^10, -1.0) == (-10,) # literal_pow
  pow = 10
  @test gradient(x -> x^pow, -1.0) == (-pow,)
  @test gradient(p -> real(2^p), 2)[1] ≈ 4*log(2)

  @test gradient(xs -> sum(xs .^ 2), [2, -1]) == ([4, -2],)
  @test gradient(xs -> sum(xs .^ 10), [3, -1]) == ([10*3^9, -10],)
  @test gradient(xs -> sum(xs .^ pow), [4, -1]) == ([pow*4^9, -10],)

  @test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
  @test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im)
  # D[(1+3I)x^p, p] /. {x->5+7I, p->2} // Conjugate
end

@test gradtest((a,b)->sum(reim(acosh(complex(a[1], b[1])))), [-2.0], [1.0])

@test gradtest((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2)
@@ -63,6 +79,7 @@ Random.seed!(0)
@test gradtest(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
@test gradtest(x -> sum(x[i] for i in 1:length(x)), randn(10))
@test_broken gradtest(x -> sum(i->x[i], 1:length(x)), randn(10)) # https://github.com/FluxML/Zygote.jl/issues/231
@test gradtest(x -> sum((i->x[i]).(1:length(x))), randn(10))

@test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
@test gradtest(x -> prod(x), (3,4))
