Special adjoint for broadcasted literal pow #518

Merged (1 commit) on Feb 27, 2020

Conversation

haampie (Contributor) commented Feb 16, 2020

Currently, taking the gradient of anything that contains a broadcasted
literal pow adds RefValue{typeof(^)}(^), and a similar entry for the
literal power itself, to the IdDict. This is probably because of the
special signature in the broadcasting machinery:

Base.broadcasted(Base.literal_pow, Main.:^, vec, %2)

where %2 is a Val{N} instance.
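
For reference, this call can be seen by inspecting the lowering of a broadcasted literal power (a rough sketch; the exact printed form depends on the Julia version):

```
julia> Meta.@lower x .^ 2
# Abbreviated output -- the lowered code contains, roughly,
#   %2 = Base.broadcasted(Base.literal_pow, ^, x, Val{2}())
#   %3 = Base.materialize(%2)
# whereas a non-literal exponent, x .^ p, lowers to Base.broadcasted(^, x, p).
```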

By adding a special adjoint for broadcasted literal_pow, we not only
reduce the noise in the params' IdDict but also speed up taking the
gradient of basic loss functions like sum(err.^2).

Ref #513; also solves most of FluxML/Flux.jl#1018.

mcabbott (Member)

Minimal timing test, which improves by 100x:

@btime gradient(x -> sum(x .^ 2), x0)  setup=(x0=collect(1.0:10^6))

However, to get the right answers, I think it needs to be more like this:

import Base: broadcasted
using Zygote: @adjoint
# A number, or an array of numbers -- the usual argument types for broadcasting rules.
Numeric{T<:Number} = Union{T,AbstractArray{<:T}}

@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, ::Val{p}) where {p}
  y = Base.literal_pow.(^, x, Val(p))
  # Pullback: d(x^p)/dx = p * x^(p-1); conj makes complex x come out right (a no-op for real x).
  # The three `nothing`s are for the literal_pow, ^ and Val(p) arguments.
  y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p-1)), nothing)
end
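
A quick sanity check against the analytic derivative, d/dx sum(x.^3) = 3 .* x.^2 (a sketch, assuming Zygote and the adjoint above are loaded):

```
julia> using Zygote

julia> x = [1.0, 2.0, 3.0];

julia> gradient(v -> sum(v .^ 3), x)[1] ≈ 3 .* x .^ 2
true
```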

haampie (Contributor, Author) commented Feb 17, 2020

Oops, yeah. Should be captured by @mcabbott's tests from #512 :)

Should be fine now!

CarloLucibello (Member)

@oxinabox could you merge this?

haampie (Contributor, Author) commented Feb 26, 2020

@mcabbott any reason to write Base.literal_pow.(^, x, Val(p)) instead of x .^ p?

mcabbott (Member) commented Feb 26, 2020

Mostly a general feeling that one shouldn't change the forward pass without good reason. But I think literal_pow is also often quite a bit quicker; it may even be worthwhile using it for the backward pass too:

julia> A = randn(10,2,5,7);

julia> my_power(A, p=2) = A .^ p;

julia> @btime my_power($A);
  14.790 μs (2 allocations: 5.67 KiB)

julia> @btime Base.literal_pow.(^, $A, Val(2)); 
  698.965 ns (2 allocations: 5.67 KiB)

Or maybe this is going to get optimised anyway?

julia> my_power_val(A, ::Val{p} = Val(2)) where {p} = A.^p;

julia> @btime my_power_val($A);
  869.286 ns (2 allocations: 5.67 KiB)

haampie (Contributor, Author) commented Feb 26, 2020

Oh... of course. Somehow I thought the compiler would fill in p as a constant and it would end up as literal_pow anyway, but it doesn't. I'll just make sure it uses literal_pow explicitly then.

mcabbott (Member)

It may actually catch it here; compare the maximal version:

julia> @adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, ::Val{p}) where {p}
         y = Base.literal_pow.(^, x, Val(p))
         y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(Base.literal_pow.(^, x, Val(p-1))), nothing)
       end

julia> @btime gradient(x -> sum(x .^ 2), x0)  setup=(x0=collect(1.0:10^6))
  2.247 ms (20 allocations: 15.26 MiB)
([2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0  …  1.999982e6, 1.999984e6, 1.999986e6, 1.999988e6, 1.99999e6, 1.999992e6, 1.999994e6, 1.999996e6, 1.999998e6, 2.0e6],)

and the minimal version:

julia> @adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, ::Val{p}) where {p}
         y = x .^ p
         y, ȳ -> (nothing, nothing, ȳ .* p .* conj.(x .^ (p-1)), nothing)
       end

julia> @btime gradient(x -> sum(x .^ 2), x0)  setup=(x0=collect(1.0:10^6))
  2.307 ms (20 allocations: 15.26 MiB)
([2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0  …  1.999982e6, 1.999984e6, 1.999986e6, 1.999988e6, 1.99999e6, 1.999992e6, 1.999994e6, 1.999996e6, 1.999998e6, 2.0e6],)

So maybe I'm just superstitious! Both are a bit quicker than the 113.114 ms without this PR.

haampie (Contributor, Author) commented Feb 26, 2020

Let's just go for the exact same expression in the forward pass; then there are no surprises.

haampie (Contributor, Author) commented Feb 26, 2020

Actually, there is a difference for p = -1, in which case literal_pow computes it via Base.div_float. So let's definitely stick to literal_pow. This PR should be ready to go then :)
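
Roughly, the difference in question (a sketch; the exact code paths depend on the Julia version):

```
x = 3.0
p = -1                                # a runtime exponent, not a literal
a = Base.literal_pow(^, x, Val(-1))   # specializes to inv(x), i.e. a single float division
b = x ^ p                             # a non-literal exponent takes the generic power method
a == b                                # usually true, but the code paths (and rounding) can differ
```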

CarloLucibello (Member)

Nice, thanks!

bors r+

bors bot (Contributor) commented Feb 27, 2020

Build succeeded
