From 0b1ed6db85b65f220454bd3a4abe2c164b225ff6 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 17 Feb 2022 17:31:29 +0530 Subject: [PATCH 01/12] Added truncated normal initialisation --- Project.toml | 1 + src/Flux.jl | 2 +- src/utils.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6695e22267..c99d681ef1 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/Flux.jl b/src/Flux.jl index f7696b6549..2b204567d0 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -3,7 +3,7 @@ module Flux # Zero Flux Given using Base: tail -using Statistics, Random, LinearAlgebra +using Statistics, Random, LinearAlgebra, SpecialFunctions using Zygote, MacroTools, ProgressLogging, Reexport using MacroTools: @forward @reexport using NNlib diff --git a/src/utils.jl b/src/utils.jl index 67ca92f597..da8c0a52c8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -232,6 +232,7 @@ true * sparse initialization: [`sparse_init`](@ref Flux.sparse_init) # References + [1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120 """ @@ -301,6 +302,45 @@ end sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...) sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...) +""" + truncated_normal([rng=GLOBAL_RNG], dims...; μ = 0, σ = 1, a = -2., b = 2.) + +Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution. +The values are effectively drawn from the normal distribution with mean `μ` and standard deviation +`σ`, with values outside `[a, b]` redrawn until they are within the bounds. The method used for +generating the random values works best when `a ≤ mean ≤ b`. + +# Examples +```jldoctest; setup = :(using Random; Random.seed!(0)) +julia> Flux.truncated_normal(3, 2) +3×2 Matrix{Float32}: + -0.113785 -0.627307 + -0.676033 0.198423 + 0.509005 -0.554339 +``` + +# References +[1] Burkardt, John. "The Truncated Normal Distribution" +[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf). +Department of Scientific Computing website. +""" +function truncated_normal(rng::AbstractRNG, dims...; μ = 0., σ = 1., a = -2., b = 2.) + norm_cdf(x) = 0.5 * (1 + erf(x/√2)) + if (μ < a - 2σ) || (μ > b + 2σ) + @warn "mean is more than 2 std from [a, b] in truncated_normal. The distribution of values + may be incorrect." + end + l = norm_cdf((a - μ) / σ) + u = norm_cdf((b - μ) / σ) + x = rand(rng, dims...) * 2(u - l) .+ (2l - 1) + x = erfinv.(x) + x = clamp.(x * σ/√2 .+ μ, a, b) + return convert.(Float32, x) +end + +truncated_normal(dims::Integer...; kwargs...) = truncated_normal(Random.GLOBAL_RNG, dims...; kwargs...) +truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...) + """ identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0) From 1c5cd9d5ee8031a66a7d64eae65fe723f787a94b Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 17 Feb 2022 18:55:21 +0530 Subject: [PATCH 02/12] Added tests Fix tests gaffe --- test/utils.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 1681a0df28..a441e14803 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,6 +1,6 @@ using Flux using Flux: throttle, nfan, glorot_uniform, glorot_normal, - kaiming_normal, kaiming_uniform, orthogonal, + kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, sparse_init, stack, unstack, Zeros, batch, unbatch, unsqueeze using StatsBase: var, std @@ -146,6 +146,15 @@ end end end + @testset "truncated_normal" begin + for sz in [(100,), (100, 100), (2, 3, 32, 64), (2, 3, 4, 32, 64)] + v = truncated_normal(sz...) + @test -1.0 < minimum(v) < 0.0 + @test 0.0 < maximum(v) < 1.0 + @test eltype(v) == Float32 + end + end + @testset "partial_application" begin big = 1e9 From 35d4ed4bb0a7ed1121b4e4837462ab222b9b4490 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 17 Feb 2022 20:54:30 +0530 Subject: [PATCH 03/12] Work entirely with Float32 Co-authored-by: Kyle Daruwalla --- src/utils.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index da8c0a52c8..7443c546b1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -324,18 +324,17 @@ julia> Flux.truncated_normal(3, 2) [PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf). Department of Scientific Computing website. """ -function truncated_normal(rng::AbstractRNG, dims...; μ = 0., σ = 1., a = -2., b = 2.) - norm_cdf(x) = 0.5 * (1 + erf(x/√2)) +function truncated_normal(rng::AbstractRNG, dims...; μ = 0, σ = 1, a = -2, b = 2) + norm_cdf(x) = 5f-1 * (1 + erf(x/√2f0)) if (μ < a - 2σ) || (μ > b + 2σ) - @warn "mean is more than 2 std from [a, b] in truncated_normal. The distribution of values - may be incorrect." + @warn "Mean is more than 2 std from [a, b] in truncated_normal. The distribution of values may be incorrect." maxlog=1 end l = norm_cdf((a - μ) / σ) u = norm_cdf((b - μ) / σ) - x = rand(rng, dims...) * 2(u - l) .+ (2l - 1) + x = rand(rng, Float32, dims...) * 2(u - l) .+ (2l - 1) x = erfinv.(x) - x = clamp.(x * σ/√2 .+ μ, a, b) - return convert.(Float32, x) + x = clamp.(x .* σ/√2f0 .+ μ, a, b) + return x end truncated_normal(dims::Integer...; kwargs...) = truncated_normal(Random.GLOBAL_RNG, dims...; kwargs...) From 00124739e305b66ff34d22de49f49cac35dcbd00 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 17 Feb 2022 23:07:36 +0530 Subject: [PATCH 04/12] Tweaks to API and tests --- src/utils.jl | 29 +++++++++++++++-------------- test/utils.jl | 10 ++++++---- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7443c546b1..92c1b76d6a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -303,12 +303,12 @@ sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...) """ - truncated_normal([rng=GLOBAL_RNG], dims...; μ = 0, σ = 1, a = -2., b = 2.) + truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2., hi = 2.) Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution. -The values are effectively drawn from the normal distribution with mean `μ` and standard deviation -`σ`, with values outside `[a, b]` redrawn until they are within the bounds. The method used for -generating the random values works best when `a ≤ mean ≤ b`. +The values are generated by using a truncated uniform distribution and then using the inverse CDF +for the normal distribution. The method used for generating the random values works best when +`lo ≤ mean ≤ hi`. # Examples ```jldoctest; setup = :(using Random; Random.seed!(0)) @@ -324,20 +324,21 @@ julia> Flux.truncated_normal(3, 2) [PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf). Department of Scientific Computing website. """ -function truncated_normal(rng::AbstractRNG, dims...; μ = 0, σ = 1, a = -2, b = 2) - norm_cdf(x) = 5f-1 * (1 + erf(x/√2f0)) - if (μ < a - 2σ) || (μ > b + 2σ) +function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2) + norm_cdf(x) = 0.5 * (1 + erf(x/√2)) + if (mean < lo - 2 * std) || (mean > hi + 2 * std) @warn "Mean is more than 2 std from [a, b] in truncated_normal. The distribution of values may be incorrect." maxlog=1 end - l = norm_cdf((a - μ) / σ) - u = norm_cdf((b - μ) / σ) - x = rand(rng, Float32, dims...) * 2(u - l) .+ (2l - 1) + l = norm_cdf((lo - mean) / std) + u = norm_cdf((hi - mean) / std) + x = rand(rng, dims...) * 2(u - l) .+ (2l - 1) x = erfinv.(x) - x = clamp.(x .* σ/√2f0 .+ μ, a, b) + x = f32(x .* std * √2 .+ mean) return x end truncated_normal(dims::Integer...; kwargs...) = truncated_normal(Random.GLOBAL_RNG, dims...; kwargs...) +truncated_normal(dims) = truncated_normal(Random.GLOBAL_RNG, dims...) truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...) """ @@ -476,8 +477,8 @@ end Flatten a model's parameters into a single weight vector. - julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax) - Chain(Dense(10, 5, σ), Dense(5, 2), softmax) + julia> m = Chain(Dense(10, 5, std), Dense(5, 2), softmax) + Chain(Dense(10, 5, std), Dense(5, 2), softmax) julia> θ, re = destructure(m); @@ -490,7 +491,7 @@ The second return value `re` allows you to reconstruct the original network afte modifications to the weight vector (for example, with a hypernetwork). julia> re(θ .* 2) - Chain(Dense(10, 5, σ), Dense(5, 2), softmax) + Chain(Dense(10, 5, std), Dense(5, 2), softmax) """ function destructure(m) xs = Zygote.Buffer([]) diff --git a/test/utils.jl b/test/utils.jl index a441e14803..fdf589c07f 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -147,10 +147,12 @@ end end @testset "truncated_normal" begin - for sz in [(100,), (100, 100), (2, 3, 32, 64), (2, 3, 4, 32, 64)] - v = truncated_normal(sz...) - @test -1.0 < minimum(v) < 0.0 - @test 0.0 < maximum(v) < 1.0 + size = (100, 100, 100) + for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)] + v = truncated_normal(size...; mean = μ, std = σ, lo, hi) + @test isapprox(mean(v), μ; atol = 1e-2) + @test isapprox(minimum(v), lo; atol = 1e-2) + @test isapprox(maximum(v), hi; atol = 1e-2) @test eltype(v) == Float32 end end From 46feb856ee371f707e6bcaeeb6ee7fa8700990f5 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 18 Feb 2022 08:12:45 +0530 Subject: [PATCH 05/12] Fixed RNGs, updated tests --- src/utils.jl | 96 ++++++++++++++++++++++++++------------------------- test/utils.jl | 14 +++++--- 2 files changed, 59 insertions(+), 51 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 92c1b76d6a..1b4feec00a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -39,7 +39,7 @@ epseltype(x) = eps(float(eltype(x))) Create an instance of the RNG most appropriate for `x`. The current defaults are: - `x isa AbstractArray` - - Julia version is < 1.7: `Random.GLOBAL_RNG` + - Julia version is < 1.7: `rng_from_array()` - Julia version is >= 1.7: `Random.default_rng()` - `x isa CuArray`: `CUDA.default_rng()` When `x` is unspecified, it is assumed to be a `AbstractArray`. @@ -81,7 +81,7 @@ julia> Flux.glorot_uniform(2, 3) [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...))) -glorot_uniform(dims...) = glorot_uniform(Random.GLOBAL_RNG, dims...) +glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...) glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...) """ @@ -114,7 +114,7 @@ julia> Flux.glorot_normal(3, 2) [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010. """ glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...))) -glorot_normal(dims...) = glorot_normal(Random.GLOBAL_RNG, dims...) +glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...) glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...) """ @@ -151,7 +151,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2) return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound end -kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...) +kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...) kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...) """ @@ -188,9 +188,50 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0) return randn(rng, Float32, dims...) .* std end -kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...) +kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...) kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) +""" + truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2., hi = 2.) + +Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution. +The values are generated by using a truncated uniform distribution and then using the inverse CDF +for the normal distribution. The method used for generating the random values works best when +`lo ≤ mean ≤ hi`. + +# Examples +```jldoctest; setup = :(using Random; Random.seed!(0)) +julia> Flux.truncated_normal(3, 2) +3×2 Matrix{Float32}: + -0.0340547 -1.35207 + -0.22757 -0.793773 + -1.75771 1.01801 +``` + +# References +[1] Burkardt, John. "The Truncated Normal Distribution" +[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf). +Department of Scientific Computing website. +""" +function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2) + norm_cdf(x) = 0.5 * (1 + erf(x/√2)) + if (mean < lo - 2 * std) || (mean > hi + 2 * std) + @warn "Mean is more than 2 std from [a, b] in truncated_normal. The distribution of values may be incorrect." maxlog=1 + end + l = norm_cdf((lo - mean) / std) + u = norm_cdf((hi - mean) / std) + xs = rand(rng, Float32, dims...) + broadcast!(xs, xs) do x + x = x * 2(u - l) + (2l - 1) + x = erfinv(x) + x = clamp.(x * std * √2f0 + mean, lo, hi) + end + return xs +end + +truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...) +truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...) + """ orthogonal([rng=GLOBAL_RNG], dims...; gain = 1) @@ -255,7 +296,7 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...) return reshape(orthogonal(rng, rows, cols; kwargs...), dims) end -orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...) +orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...) orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...) """ @@ -299,48 +340,9 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01) return mapslices(shuffle, sparse_array, dims=1) end -sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...) +sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...) sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...) -""" - truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2., hi = 2.) - -Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution. -The values are generated by using a truncated uniform distribution and then using the inverse CDF -for the normal distribution. The method used for generating the random values works best when -`lo ≤ mean ≤ hi`. - -# Examples -```jldoctest; setup = :(using Random; Random.seed!(0)) -julia> Flux.truncated_normal(3, 2) -3×2 Matrix{Float32}: - -0.113785 -0.627307 - -0.676033 0.198423 - 0.509005 -0.554339 -``` - -# References -[1] Burkardt, John. "The Truncated Normal Distribution" -[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf). -Department of Scientific Computing website. -""" -function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2) - norm_cdf(x) = 0.5 * (1 + erf(x/√2)) - if (mean < lo - 2 * std) || (mean > hi + 2 * std) - @warn "Mean is more than 2 std from [a, b] in truncated_normal. The distribution of values may be incorrect." maxlog=1 - end - l = norm_cdf((lo - mean) / std) - u = norm_cdf((hi - mean) / std) - x = rand(rng, dims...) * 2(u - l) .+ (2l - 1) - x = erfinv.(x) - x = f32(x .* std * √2 .+ mean) - return x -end - -truncated_normal(dims::Integer...; kwargs...) = truncated_normal(Random.GLOBAL_RNG, dims...; kwargs...) -truncated_normal(dims) = truncated_normal(Random.GLOBAL_RNG, dims...) -truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...) - """ identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0) @@ -422,7 +424,7 @@ function identity_init(dims...; gain=1, shift=0) end identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...) -identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...) +identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...) identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...) ones32(dims...) = Base.ones(Float32, dims...) diff --git a/test/utils.jl b/test/utils.jl index fdf589c07f..4481d7f607 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -2,8 +2,9 @@ using Flux using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, orthogonal, truncated_normal, sparse_init, stack, unstack, Zeros, batch, unbatch, - unsqueeze + unsqueeze, params using StatsBase: var, std +using Statistics, LinearAlgebra using Random using Test @@ -149,10 +150,15 @@ end @testset "truncated_normal" begin size = (100, 100, 100) for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)] + v = truncated_normal(size; mean = μ, std = σ, lo, hi) + @test isapprox(mean(v), μ; atol = 1f-2) + @test isapprox(minimum(v), lo; atol = 1f-2) + @test isapprox(maximum(v), hi; atol = 1f-2) + @test eltype(v) == Float32 + end + for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)] v = truncated_normal(size...; mean = μ, std = σ, lo, hi) - @test isapprox(mean(v), μ; atol = 1e-2) - @test isapprox(minimum(v), lo; atol = 1e-2) - @test isapprox(maximum(v), hi; atol = 1e-2) + @test isapprox(std(v), σ; atol = 1f-2) @test eltype(v) == Float32 end end From ebc69a9cf6f35c16c0e77e079fe75440177cfe73 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 18 Feb 2022 09:06:32 +0530 Subject: [PATCH 06/12] Apply suggestions from code review Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/utils.jl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 1b4feec00a..662425150d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -192,12 +192,13 @@ kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; k kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...) """ - truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2., hi = 2.) + truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2) -Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution. -The values are generated by using a truncated uniform distribution and then using the inverse CDF -for the normal distribution. The method used for generating the random values works best when -`lo ≤ mean ≤ hi`. +Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution. +The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* rand(10^6))`. + +The values are generated using uniform `rand()` and then the inverse CDF. +This method works best when `lo ≤ mean ≤ hi`. # Examples ```jldoctest; setup = :(using Random; Random.seed!(0)) @@ -216,7 +217,7 @@ Department of Scientific Computing website. function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2) norm_cdf(x) = 0.5 * (1 + erf(x/√2)) if (mean < lo - 2 * std) || (mean > hi + 2 * std) - @warn "Mean is more than 2 std from [a, b] in truncated_normal. The distribution of values may be incorrect." maxlog=1 + @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1 end l = norm_cdf((lo - mean) / std) u = norm_cdf((hi - mean) / std) @@ -224,7 +225,7 @@ function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, broadcast!(xs, xs) do x x = x * 2(u - l) + (2l - 1) x = erfinv(x) - x = clamp.(x * std * √2f0 + mean, lo, hi) + x = clamp(x * std * √2 + mean, lo, hi) end return xs end @@ -493,7 +494,7 @@ The second return value `re` allows you to reconstruct the original network afte modifications to the weight vector (for example, with a hypernetwork). julia> re(θ .* 2) - Chain(Dense(10, 5, std), Dense(5, 2), softmax) + Chain(Dense(10, 5, σ), Dense(5, 2), softmax) """ function destructure(m) xs = Zygote.Buffer([]) From 81bd8eca5454eb2397deaae71f0de7911060340f Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 19 Feb 2022 08:33:54 +0530 Subject: [PATCH 07/12] Apply suggestions from code review Co-authored-by: Kyle Daruwalla --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 662425150d..eeed4d8ee2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -197,7 +197,9 @@ kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaimi Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution. The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* rand(10^6))`. -The values are generated using uniform `rand()` and then the inverse CDF. +The values are generated by sampling a Uniform(0, 1) (`rand()`) and then +applying the inverse CDF of the truncated normal distribution +(see the references for more info). This method works best when `lo ≤ mean ≤ hi`. # Examples From f030521d5a16a4a2ff10202d5f13babafbd33812 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 19 Feb 2022 08:34:54 +0530 Subject: [PATCH 08/12] Update utils.jl --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index eeed4d8ee2..35eec9705b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -39,7 +39,7 @@ epseltype(x) = eps(float(eltype(x))) Create an instance of the RNG most appropriate for `x`. The current defaults are: - `x isa AbstractArray` - - Julia version is < 1.7: `rng_from_array()` + - Julia version is < 1.7: `Random.GLOBAL_RNG` - Julia version is >= 1.7: `Random.default_rng()` - `x isa CuArray`: `CUDA.default_rng()` When `x` is unspecified, it is assumed to be a `AbstractArray`. From 8403e312018c4d0b51e16184bb31aedb0779dc48 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 19 Feb 2022 09:03:18 +0530 Subject: [PATCH 09/12] Increased tolerance, modified doctests Multiple modifications to doctrings --- src/utils.jl | 17 ++++++++++------- test/utils.jl | 8 ++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 35eec9705b..096de4fba0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -195,7 +195,7 @@ kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaimi truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2) Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution. -The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* rand(10^6))`. +The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(dims...))`. The values are generated by sampling a Uniform(0, 1) (`rand()`) and then applying the inverse CDF of the truncated normal distribution @@ -203,12 +203,15 @@ applying the inverse CDF of the truncated normal distribution This method works best when `lo ≤ mean ≤ hi`. # Examples -```jldoctest; setup = :(using Random; Random.seed!(0)) -julia> Flux.truncated_normal(3, 2) -3×2 Matrix{Float32}: - -0.0340547 -1.35207 - -0.22757 -0.793773 - -1.75771 1.01801 +```jldoctest; setup = :(using Statistics) +julia> Flux.truncated_normal(3, 4) |> summary +"3×4 Matrix{Float32}" + +julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3) +(-2.0f0, 2.0f0) + +julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100))) +1.0f0 ``` # References diff --git a/test/utils.jl b/test/utils.jl index 4481d7f607..3f4efe24ba 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -151,14 +151,14 @@ end size = (100, 100, 100) for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)] v = truncated_normal(size; mean = μ, std = σ, lo, hi) - @test isapprox(mean(v), μ; atol = 1f-2) - @test isapprox(minimum(v), lo; atol = 1f-2) - @test isapprox(maximum(v), hi; atol = 1f-2) + @test isapprox(mean(v), μ; atol = 1f-1) + @test isapprox(minimum(v), lo; atol = 1f-1) + @test isapprox(maximum(v), hi; atol = 1f-1) @test eltype(v) == Float32 end for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)] v = truncated_normal(size...; mean = μ, std = σ, lo, hi) - @test isapprox(std(v), σ; atol = 1f-2) + @test isapprox(std(v), σ; atol = 1f-1) @test eltype(v) == Float32 end end From e0b74ed2df14d18e7e6eef71bc104aac8c79ebbd Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 19 Feb 2022 10:01:40 +0530 Subject: [PATCH 10/12] Update src/utils.jl Co-authored-by: Kyle Daruwalla --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 096de4fba0..19b020704b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -203,7 +203,9 @@ applying the inverse CDF of the truncated normal distribution This method works best when `lo ≤ mean ≤ hi`. # Examples -```jldoctest; setup = :(using Statistics) +```jldoctest +julia> using Statistics + julia> Flux.truncated_normal(3, 4) |> summary "3×4 Matrix{Float32}" From 6ed3a71a27e5a8ee8eaf5f8955290cabe34d7e8c Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 19 Feb 2022 10:41:20 +0530 Subject: [PATCH 11/12] Update NEWS.md --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 54959f6902..09e68e8702 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,6 +9,7 @@ been removed in favour of MLDatasets.jl. * `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable. * Many utily functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874). * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`. +* Added truncated normal initialisation of weights. ## v0.12.10 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838) From 423a9965dfe6915474b998df8531ce93b4a67b93 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 19 Feb 2022 19:57:06 +0530 Subject: [PATCH 12/12] Update NEWS.md Co-authored-by: Kyle Daruwalla --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 09e68e8702..278e006b2d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,7 +9,7 @@ been removed in favour of MLDatasets.jl. * `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable. * Many utily functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874). * The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`. -* Added truncated normal initialisation of weights. +* Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights. ## v0.12.10 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)