Zygote gives extra gradient entries for BatchNorm #1018

Open

xukai92 opened this issue Feb 3, 2020 · 5 comments

xukai92 commented Feb 3, 2020

using Flux

X = rand(2, 5)

layer = BatchNorm(2)
ps = params(layer)

gs = gradient(ps) do
    sum(layer(X))
end

gs.grads

gives

IdDict{Any,Any} with 5 entries:
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  BatchNorm(2)               => RefValue{Any}((λ = nothing, β = [5.0, 5.0], γ =
  Float32[1.0, 1.0]          => [-9.22873e-16, 4.44089e-16]
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
  Float32[0.0, 0.0]          => [5.0, 5.0]

FYI, ps is as expected:

Params([Float32[0.0, 0.0], Float32[1.0, 1.0]])
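
Note that the extra RefValue/BatchNorm keys do not prevent looking up the gradients of the tracked parameters; indexing the result with each parameter still works. A minimal sketch, continuing the snippet above:

```
# The Grads object can be indexed by the tracked parameter arrays directly,
# even though gs.grads also holds the extra RefValue/BatchNorm entries.
for p in ps
    @show p gs[p]   # β and γ each map to a length-2 gradient vector
end
```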
haampie commented Feb 13, 2020

I've been looking into this a bit. By writing

--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -153,7 +153,7 @@ function (BN::BatchNorm)(x)
     T = eltype(x)
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
     μ = mean(x, dims = axes)
-    σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
+    σ² = sum((x .- μ) .* (x .- μ), dims = axes) ./ m
     ϵ = convert(T, BN.ϵ)
     # update moving mean/std
     mtm = BN.momentum

the following two entries are no longer present:

  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  BatchNorm(2)               => RefValue{Any}((λ = nothing, β = [5.0, 5.0], γ =…

If I change .^2 to .^2.0f0, the gradient throws:

julia> example = gradient(ps) do
           sum(layer(X))
       end
ERROR: DomainError with -0.14608962594708508:
log will only return a complex result if called with a complex argument. Try log(Complex(x)).
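
The DomainError is consistent with the generic power rule: for a non-literal floating-point exponent, the pullback of x ^ p also carries a ∂/∂p term proportional to x^p * log(x), and x .- μ has negative entries. A minimal scalar sketch of the same failure, assuming the Zygote behavior of the time (the "Gradient for power" PR referenced further down is what addressed this):

```
using Zygote

# With a float exponent, the derivative with respect to the exponent involves
# log(x), which is evaluated eagerly and throws for negative x on Zygote
# versions predating the power-gradient fix referenced below.
gradient(x -> x ^ 2.0f0, -0.5)
```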

haampie commented Feb 13, 2020

And the BatchNorm(2) entry seems to happen because of mutable structs?

julia> mutable struct MutableLayer{T}; a::T; end;

julia> struct ImmutableLayer{T}; a::T; end;

julia> (layer::ImmutableLayer)() = sum(layer.a);

julia> (layer::MutableLayer)() = sum(layer.a);

julia> Flux.trainable(a::ImmutableLayer) = (a.a,);

julia> Flux.trainable(a::MutableLayer) = (a.a,);

julia> Flux.@functor ImmutableLayer;

julia> Flux.@functor MutableLayer;

julia> mutable_layer = MutableLayer(rand(1));

julia> immutable_layer = ImmutableLayer(rand(1));

julia> gradient(mutable_layer, params(mutable_layer)).grads
IdDict{Any,Any} with 2 entries:
  MutableLayer{Array{Float64,1}}([0.500088]) => RefValue{Any}((a = [1.0],))
  [0.500088]                                 => [1.0]

julia> gradient(immutable_layer, params(immutable_layer)).grads
IdDict{Any,Any} with 1 entry:
  [0.619024] => [1.0]
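
In other words, the extra key appears to be the mutable struct itself, stored by identity in the IdDict. One way to check that directly, continuing the snippet above (a sketch, not a fix):

```
# The mutable layer object shows up as its own key, alongside its parameter
# array; the immutable layer does not get an identity-keyed entry.
gs_mut = gradient(mutable_layer, params(mutable_layer))
haskey(gs_mut.grads, mutable_layer)    # true
haskey(gs_mut.grads, mutable_layer.a)  # true

gs_imm = gradient(immutable_layer, params(immutable_layer))
haskey(gs_imm.grads, immutable_layer)  # false
```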

haampie commented Feb 13, 2020

MWE for the ^ issue:

julia> f() = 1 .^ 2;

julia> gradient(f, params()).grads
IdDict{Any,Any} with 2 entries:
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
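
Those two keys match how Julia lowers a broadcast with a literal integer exponent: it goes through Base.literal_pow with a Val{2} argument. One way to see the lowering (a sketch; the lowered output is summarized in the comments):

```
# Meta.@lower shows the lowered form of the broadcast; among other things it
# contains a call like
#   Base.broadcasted(Base.literal_pow, ^, 1, Val{2}())
# which is where the ^ and Val{2}() keys above come from.
Meta.@lower 1 .^ 2
```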

DhairyaLGandhi (Member) commented

I think this is known?

bors bot added a commit to FluxML/Zygote.jl that referenced this issue Feb 27, 2020
512: Gradient for power r=CarloLucibello a=mcabbott

Closes #511, closes #247, and #426.

518: Special adjoint for broadcasted literal pow r=CarloLucibello a=haampie

Currently taking the gradient of anything that contains a broadcasted
literal pow adds RefValue{typeof(^)}(^) and a similar entry for the
literal power itself to the IdDict. This is probably because of the
special signature in the broadcasting machinery:

```
Base.broadcasted(Base.literal_pow, Main.:^, vec, %2)
```

where %2 is a Val{N} instance.

By adding a special adjoint for the broadcasted literal_pow, we not only
reduce the noise in the params' IdDict but also speed up taking the
gradient of basic loss functions like sum(err.^2).

Ref #513 and also solves FluxML/Flux.jl#1018 (most of it)

Co-authored-by: Michael Abbott <me@escbook>
Co-authored-by: Harmen Stoppels <[email protected]>
haampie commented Feb 27, 2020

FluxML/Zygote.jl#518 fixes the f() = 1 .^ 2 example and solves this issue for the most part.

The only entry left from @xukai92's example would be BatchNorm(2) => ..., but I don't know how to solve that one.
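
For reference, the fix in FluxML/Zygote.jl#518 works by giving the broadcasted literal_pow call its own adjoint, so the ^ function and the Val{2} argument never get gradient entries of their own. A minimal sketch of that kind of adjoint, assuming Zygote's @adjoint macro (not necessarily the exact code that was merged):

```
using Zygote

# Special-case the lowered form Base.broadcasted(Base.literal_pow, ^, x, Val(p)):
# compute x .^ p in the forward pass and return nothing for the literal_pow
# function, the ^ operator, and the Val argument, so only x gets a gradient entry.
Zygote.@adjoint function Base.broadcasted(::typeof(Base.literal_pow), ::typeof(^),
                                          x::AbstractArray, ::Val{p}) where p
    x .^ p, ȳ -> (nothing, nothing, ȳ .* p .* x .^ (p - 1), nothing)
end
```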
