From 91f2d47851a146047626e75a192bd3f45eee4840 Mon Sep 17 00:00:00 2001 From: Lukas Billera <125421968+billera@users.noreply.github.com> Date: Tue, 5 Nov 2024 12:11:51 +0100 Subject: [PATCH] Epsilon change in normalise for stability (#2421) * epsilon change for stability * Change comment for eps Co-authored-by: Carlo Lucibello --------- Co-authored-by: Carlo Lucibello Co-authored-by: Carlo Lucibello --- src/layers/stateless.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 4fb739e0c3..ed85b1a92c 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,10 +1,10 @@ """ - normalise(x; dims=ndims(x), eps=1e-5) + normalise(x; dims=ndims(x), eps=1f-5) Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`. Per default, `dims` is the last dimension. -`eps` is a small term added to the denominator for numerical stability. +`eps` is a small term added to the variance for numerical stability. # Examples ```jldoctest @@ -34,10 +34,11 @@ julia> isapprox(std(y; dims=1, corrected=false), ones(1, 10), atol=1e-5) true ``` """ -@inline function normalise(x::AbstractArray; dims=ndims(x), eps=ofeltype(x, 1e-5)) +@inline function normalise(x::AbstractArray; dims=ndims(x), eps=1f-5) μ = mean(x, dims=dims) - σ = std(x, dims=dims, mean=μ, corrected=false) - return @. (x - μ) / (σ + eps) + σ² = var(x, dims=dims, mean=μ, corrected=false) + ε = ofeltype(x, eps) + return @. (x - μ) / sqrt(σ² + ε^2) end """