Special adjoint for broadcasted literal pow
Currently, taking the gradient of anything that contains a broadcasted
literal pow adds RefValue{typeof(^)}(^), and a similar entry for the
literal power itself, to the IdDict. This is probably because of the
special signature in the broadcasting machinery:

```
Base.broadcasted(Base.literal_pow, Main.:^, vec, Val{N}())
```
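The signature above can be reproduced with Base alone, without Zygote. The following is a minimal sketch, not part of the commit: the vector `x` and the names `bc`, `y`, `wrapped`, and `same_as_dot_syntax` are hypothetical, and N is fixed to 2. It also shows that broadcasting wraps the function `^` in a `RefValue`, which is plausibly where the extra entry mentioned above comes from.

```julia
# Sketch: build by hand the broadcast call that `x .^ 2` lowers to.
x = [1.0, 2.0, 3.0]                    # hypothetical input vector

# Same shape as the signature in the commit message, with N = 2.
bc = Base.broadcasted(Base.literal_pow, ^, x, Val(2))
y = Base.materialize(bc)               # evaluate the lazy broadcast

# The hand-built call agrees with the ordinary dot syntax.
same_as_dot_syntax = (y == x .^ 2)

# Broadcasting treats a function argument as a scalar by wrapping it in a Ref.
wrapped = Base.broadcastable(^)
```

Because `^` and `Val(2)` travel through the broadcast as ordinary arguments, a generic broadcast adjoint has to treat them like any other input, which is the overhead the dedicated adjoint avoids.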

By adding a special adjoint for broadcasting literal_pow, we not only
reduce the noise in the params' IdDict but also speed up taking the
gradient of basic loss functions like sum(err.^2).
haampie committed Feb 26, 2020
1 parent af498fa commit 3ec5b4b
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/lib/broadcast.jl
@@ -76,6 +76,11 @@ Numeric{T<:Number} = Union{T,AbstractArray{<:T}}
  res, Δ -> (nothing, unbroadcast(x, Δ ./ y), unbroadcast(y, -Δ .* res ./ y))
end

@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
  y = Base.literal_pow.(^, x, exp)
  y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
end

@adjoint broadcasted(::typeof(identity), x::Numeric) = x, Δ -> (nothing, Δ)

@adjoint function broadcasted(::typeof(σ), x::Numeric)

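For reference, the pullback returned by the new literal_pow adjoint can be read as the elementwise power rule. The derivation below is a sketch, assuming the conjugate in the code reflects the usual convention for pulling back through a holomorphic function; the symbols \bar{x}, \bar{y}, and e are notation introduced here, not names from the source.

```latex
\begin{align*}
y_i &= x_i^{p}, \\
\bar{x}_i &= \bar{y}_i \,\overline{\frac{\partial y_i}{\partial x_i}}
           = \bar{y}_i \, p \, \overline{x_i^{\,p-1}}
           && \text{($p$ is a real literal, so $\overline{p} = p$)}, \\
L &= \sum_i e_i^{2}
   \;\Rightarrow\; \bar{y}_i = 1,\; p = 2,\;
   \bar{e}_i = 2\,\overline{e_i} = 2 e_i \quad \text{for real } e_i .
\end{align*}
```

The three `nothing` entries in the returned tuple correspond to `literal_pow`, `^`, and the `Val{p}` exponent, none of which receives a gradient.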