From eed45560f568c1dd2c80a3c5f774750e79658fc3 Mon Sep 17 00:00:00 2001
From: Harmen Stoppels
Date: Sun, 16 Feb 2020 15:33:33 +0100
Subject: [PATCH] Special adjoint for broadcasted literal pow

Currently taking the gradient of anything that contains a broadcasted
literal pow adds RefValue{typeof(^)}(^) and a similar entry for the
literal power itself to the IdDict. This is probably because of the
special signature in the broadcasting machinery:

```
Base.broadcasted(Base.literal_pow, Main.:^, vec, Val{N}())
```

By adding a special adjoint for broadcasting literal_pow, not only
do we reduce the noise in the param's IdDict, but it also speeds
up taking the gradient of basic loss functions like sum(err.^2).
---
 src/lib/broadcast.jl | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 47a718a8c..2c4b94521 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -76,6 +76,17 @@ Numeric{T<:Number} = Union{T,AbstractArray{<:T}}
   res, Δ -> (nothing, unbroadcast(x, Δ ./ y), unbroadcast(y, -Δ .* res ./ y))
 end
 
+# A broadcasted literal power such as `x .^ 2` reaches the broadcasting machinery
+# as `broadcasted(Base.literal_pow, ^, x, Val(2))`. This dedicated adjoint returns
+# `nothing` for the function and `Val` arguments instead of tracking them.
+@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, power::Val{p}) where p
+  # Evaluate through `literal_pow` so the primal matches the undifferentiated
+  # program (e.g. negative literal exponents on integer inputs).
+  y = Base.literal_pow.(^, x, power)
+  # Power rule: d/dx x^p = p * x^(p - 1), conjugated before scaling by ȳ.
+  y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
+end
+
 @adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ)
 
 @adjoint function broadcasted(::typeof(σ), x::Numeric)