Skip to content

Commit

Permalink
Special adjoint for broadcasted literal pow
Browse files Browse the repository at this point in the history
Currently taking the gradient of anything that contains a broadcasted
literal pow adds RefValue{typeof(^)}(^) and a similar entry for the
literal power itself to the IdDict. This is probably because of the
special signature in the broadcasting machinery:

```
Base.broadcasted(Base.literal_pow, Main.:^, 2, %2)
```

where %2 is a Val{N} instance.

By adding a special adjoint for broadcasting literal_pow, not only
do we reduce the noise in the param's IdDict, but it also speeds
up taking the gradient of basic loss functions like sum(err.^2).
  • Loading branch information
haampie committed Feb 26, 2020
1 parent ab08984 commit 0ccc53a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Numeric{T<:Number} = Union{T,AbstractArray{<:T}}
res, Δ -> (nothing, unbroadcast(x, Δ ./ y), unbroadcast(y, -Δ .* res ./ y))
end

@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(Main.:^), x::Numeric, y::Val{p}) where p
@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, ::Val{p}) where p
y = x .^ p
y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p - 1)), nothing)
end
Expand Down

0 comments on commit 0ccc53a

Please sign in to comment.