Zygote gives extra gradient entries for BatchNorm #1018
I've been looking into this a bit. By writing

```diff
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -153,7 +153,7 @@ function (BN::BatchNorm)(x)
   T = eltype(x)
   axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
   μ = mean(x, dims = axes)
-  σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
+  σ² = sum((x .- μ) .* (x .- μ), dims = axes) ./ m
   ϵ = convert(T, BN.ϵ)
   # update moving mean/std
   mtm = BN.momentum
```

the following two entries are no longer present:
If I change
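For context on what the patch above works around, the extra entries can be reproduced roughly like this (a sketch, not code from the issue; `bn`, `x`, and the loss are illustrative, and the exact keys depend on the Flux/Zygote versions):

```julia
using Flux  # re-exports Zygote's gradient and params

# Illustrative reproduction: look at the keys of the gradient IdDict for a BatchNorm layer.
bn = BatchNorm(2)              # 2 channels
x  = rand(Float32, 2, 4)       # (channels, batch)

gs = gradient(() -> sum(bn(x)), params(bn))

# Besides the entries for bn.β and bn.γ, the IdDict can also hold keys such as
# RefValue{typeof(^)}(^) and RefValue{Val{2}}(Val{2}()) (from the broadcasted
# literal pow in the variance computation) and the BatchNorm struct itself.
foreach(k -> println(typeof(k)), keys(gs.grads))
```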
And the difference between a mutable and an immutable layer:

```julia
julia> mutable struct MutableLayer{T}; a::T; end;

julia> struct ImmutableLayer{T}; a::T; end;

julia> (layer::ImmutableLayer)() = sum(layer.a);

julia> (layer::MutableLayer)() = sum(layer.a);

julia> Flux.trainable(a::ImmutableLayer) = (a.a,);

julia> Flux.trainable(a::MutableLayer) = (a.a,);

julia> Flux.@functor ImmutableLayer;

julia> Flux.@functor MutableLayer;

julia> mutable_layer = MutableLayer(rand(1));

julia> immutable_layer = ImmutableLayer(rand(1));

julia> gradient(mutable_layer, params(mutable_layer)).grads
IdDict{Any,Any} with 2 entries:
  MutableLayer{Array{Float64,1}}([0.500088]) => RefValue{Any}((a = [1.0],))
  [0.500088]                                 => [1.0]

julia> gradient(immutable_layer, params(immutable_layer)).grads
IdDict{Any,Any} with 1 entry:
  [0.619024] => [1.0]
```
MWE for the extra `RefValue` entries:

```julia
julia> f() = 1 .^ 2;

julia> gradient(f, params()).grads
IdDict{Any,Any} with 2 entries:
  RefValue{typeof(^)}(^)     => RefValue{Any}((x = nothing,))
  RefValue{Val{2}}(Val{2}()) => RefValue{Any}((x = nothing,))
```
I think this is known?
512: Gradient for power r=CarloLucibello a=mcabbott

Closes #511, closes #247, and #426.

518: Special adjoint for broadcasted literal pow r=CarloLucibello a=haampie

Currently, taking the gradient of anything that contains a broadcasted literal pow adds `RefValue{typeof(^)}(^)` and a similar entry for the literal power itself to the IdDict. This is probably because of the special signature in the broadcasting machinery:

```
Base.broadcasted(Base.literal_pow, Main.:^, vec, %2)
```

where %2 is a `Val{N}` instance. By adding a special adjoint for broadcasting `literal_pow`, not only do we reduce the noise in the params' IdDict, but we also speed up taking the gradient of basic loss functions like `sum(err.^2)`.

Ref #513, and also solves FluxML/Flux.jl#1018 (most of it).

Co-authored-by: Michael Abbott <me@escbook>
Co-authored-by: Harmen Stoppels <[email protected]>
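For reference, such an adjoint can be sketched roughly as follows (a simplified sketch, not the exact code merged in Zygote.jl#518; it assumes Zygote's `@adjoint` macro and ignores details such as complex-conjugate handling):

```julia
using Zygote
using Zygote: @adjoint

# Sketch: handle the broadcasted Base.literal_pow call (what `x .^ 2` lowers to) in a
# single adjoint, so `^` and the Val{p} exponent are treated as non-differentiable
# (they get `nothing` gradients and therefore add no extra keys to the IdDict).
@adjoint function Base.Broadcast.broadcasted(::typeof(Base.literal_pow), ::typeof(^), x, ::Val{p}) where {p}
    y = x .^ p
    back(ȳ) = (nothing, nothing, ȳ .* p .* x .^ (p - 1), nothing)
    return y, back
end
```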
FluxML/Zygote.jl#518 fixes the
The only entry left from @xukai92's example would be
gives
FYI, `ps` is as expected:
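On that note, the stray IdDict entries are mostly harmless in practice, because gradients are normally read back by indexing the `Grads` object with the tracked parameters. A sketch (illustrative model and loss, assuming Flux's implicit-`params` API):

```julia
using Flux  # re-exports gradient and params

m  = Dense(3, 2)
x  = rand(Float32, 3, 5)
ps = Flux.params(m)

# The loss uses a broadcasted literal pow, so gs.grads may pick up extra keys,
# but the lookups below only touch the parameters tracked in ps.
gs = gradient(() -> sum(m(x) .^ 2), ps)

for p in ps
    @show size(gs[p])   # gradients for the weight and bias, looked up by identity
end
```

Optimiser updates via `Flux.Optimise.update!(opt, ps, gs)` likewise only touch the parameters in `ps`.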