Special adjoint for broadcasted literal pow #518
Conversation
Minimal timing test, which improves by 100x:

```julia
@btime gradient(x -> sum(x .^ 2), x0) setup=(x0=collect(1.0:10^6))
```

However, to get the right answers, I think it needs to be more like this:

```julia
import Base: broadcasted
using Zygote: @adjoint

Numeric{T<:Number} = Union{T,AbstractArray{<:T}}

@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, ::Val{p}) where {p}
    y = Base.literal_pow.(^, x, Val(p))
    y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p-1)), nothing)
end
```
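As a sanity check on the pullback rule `ȳ .* p .* conj.(x .^ (p-1))`, here is a minimal standalone sketch (Base only, no Zygote; the names `f`, `analytic_grad`, and `fd_grad` are illustrative, not part of the PR) comparing it against a central finite-difference gradient of `sum(x .^ p)`:

```julia
# Check that the pullback formula ȳ .* p .* conj.(x .^ (p-1)) matches a
# finite-difference gradient of f(x) = sum(x .^ p), for real x.
p = 3
f(x) = sum(x .^ p)

x = [0.5, 1.0, 2.0]
ȳ = 1.0                               # seed for a scalar loss

analytic_grad = ȳ .* p .* conj.(x .^ (p - 1))

# Central finite differences, one coordinate at a time.
h = 1e-6
fd_grad = map(eachindex(x)) do i
    xp = copy(x); xp[i] += h
    xm = copy(x); xm[i] -= h
    (f(xp) - f(xm)) / (2h)
end

@assert maximum(abs.(analytic_grad .- fd_grad)) < 1e-6
```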
@oxinabox could you merge this?
@mcabbott any reason to write
Mostly a general feeling that one shouldn't change the forward pass without good reason. But I think
Or maybe this is going to get optimised anyway?
Oh... of course. Somehow I thought the compiler would fill in the
Currently, taking the gradient of anything that contains a broadcasted literal pow adds `RefValue{typeof(^)}(^)`, and a similar entry for the literal power itself, to the IdDict. This is probably because of the special signature in the broadcasting machinery:

```julia
Base.broadcasted(Base.literal_pow, Main.:^, vec, Val{N}())
```

By adding a special adjoint for broadcasting `literal_pow`, not only do we reduce the noise in the params' IdDict, but it also speeds up taking the gradient of basic loss functions like `sum(err.^2)`.
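For context, a minimal sketch (Base only, no Zygote) of the lowering that produces this special signature: `x .^ 2` does not lower to a plain broadcast of `^`, but to a broadcasted call of `Base.literal_pow` carrying the exponent in a `Val`:

```julia
# `x .^ 2` lowers to Base.broadcasted(Base.literal_pow, ^, x, Val(2)),
# which is lazy; Base.materialize collects it into an array.
x = [1.0, 2.0, 3.0]

bc = Base.broadcasted(Base.literal_pow, ^, x, Val(2))
y = Base.materialize(bc)

@assert y == x .^ 2 == [1.0, 4.0, 9.0]

# literal_pow itself on a scalar:
@assert Base.literal_pow(^, 3.0, Val(2)) == 9.0
```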
It may actually catch it here, compare maximal:
and minimal:
So maybe I'm just superstitious! Both a bit quicker than
Let's just go for the exact same expression in the forward pass; then there are no surprises.
Actually, there is a difference for
Nice, thanks! bors r+
Build succeeded
Ref #513 and also solves FluxML/Flux.jl#1018 (most of it)