From 790eb84d69f857cab68c96b2c8f05c37a0663ea5 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Mon, 18 Dec 2023 19:33:56 -0800 Subject: [PATCH] Non-diff shape handling in norm layers This reduces some latency when using Zygote. --- src/layers/normalise.jl | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index d38be00df3..1c8fbff5a1 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -226,7 +226,9 @@ function _norm_layer_forward( l, x::AbstractArray{T, N}; reduce_dims, affine_shape, ) where {T, N} if !_isactive(l, x) && l.track_stats # testmode with tracked stats - stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + stats_shape = ChainRulesCore.ignore_derivatives() do + ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + end μ = reshape(l.μ, stats_shape) σ² = reshape(l.σ², stats_shape) else # trainmode or testmode without tracked stats @@ -347,7 +349,9 @@ trainable(bn::BatchNorm) = hasaffine(bn) ? (β = bn.β, γ = bn.γ) : (;) function (BN::BatchNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(BN, x, N-1 => BN.chs) reduce_dims = [1:N-2; N] - affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + affine_shape = ChainRulesCore.ignore_derivatives() do + ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + end return _norm_layer_forward(BN, x; reduce_dims, affine_shape) end @@ -439,7 +443,9 @@ trainable(in::InstanceNorm) = hasaffine(in) ? (β = in.β, γ = in.γ) : (;) function (l::InstanceNorm)(x::AbstractArray{T,N}) where {T,N} _size_check(l, x, N-1 => l.chs) reduce_dims = 1:N-2 - affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + affine_shape = ChainRulesCore.ignore_derivatives() do + ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + end return _norm_layer_forward(l, x; reduce_dims, affine_shape) end @@ -456,10 +462,10 @@ end """ GroupNorm(channels::Int, G::Int, λ = identity; - initβ = zeros32, + initβ = zeros32, initγ = ones32, - affine = true, - eps = 1f-5, + affine = true, + eps = 1f-5, momentum = 0.1f0) [Group Normalization](https://arxiv.org/abs/1803.08494) layer. @@ -538,12 +544,14 @@ function GroupNorm(chs::Int, G::Int, λ=identity; end function (gn::GroupNorm)(x::AbstractArray) - _size_check(gn, x, ndims(x)-1 => gn.chs) + _size_check(gn, x, ndims(x)-1 => gn.chs) sz = size(x) x2 = reshape(x, sz[1:end-2]..., sz[end-1]÷gn.G, gn.G, sz[end]) N = ndims(x2) # == ndims(x)+1 reduce_dims = 1:N-2 - affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x2, i) : 1, N) + affine_shape = ChainRulesCore.ignore_derivatives() do + ntuple(i -> i ∈ (N-1, N-2) ? size(x2, i) : 1, N) + end x3 = _norm_layer_forward(gn, x2; reduce_dims, affine_shape) return reshape(x3, sz) end