Merge pull request #1877 from theabhirath/trunc-normal
Truncated normal initialisation for weights
darsnack authored Feb 19, 2022
2 parents 13a65be + 423a996 commit b35b23b
Showing 5 changed files with 81 additions and 12 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -9,6 +9,7 @@ been removed in favour of MLDatasets.jl.
* `Dropout` gained improved compatibility with Int and Complex arrays and is now twice-differentiable.
* Many utility functions and the `DataLoader` are [now provided by MLUtils.jl](https://github.com/FluxML/Flux.jl/pull/1874).
* The DataLoader is now compatible with generic dataset types implementing `MLUtils.numobs` and `MLUtils.getobs`.
* Added [truncated normal initialisation](https://github.com/FluxML/Flux.jl/pull/1877) of weights.

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
1 change: 1 addition & 0 deletions Project.toml
@@ -16,6 +16,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -3,7 +3,7 @@ module Flux
# Zero Flux Given

using Base: tail
using Statistics, Random, LinearAlgebra
using Statistics, Random, LinearAlgebra, SpecialFunctions
using Zygote, MacroTools, ProgressLogging, Reexport
using MacroTools: @forward
@reexport using NNlib
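
# A minimal illustrative check (a sketch, not part of this diff) of why the new
# dependency is pulled in: `truncated_normal` in src/utils.jl relies on the error
# function and its inverse, which SpecialFunctions.jl provides.
using SpecialFunctions: erf, erfinv
erfinv(erf(0.3)) ≈ 0.3   # the forward/inverse pair used by the truncated-normal sampler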
68 changes: 59 additions & 9 deletions src/utils.jl
@@ -81,7 +81,7 @@ julia> Flux.glorot_uniform(2, 3)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
glorot_uniform(dims...) = glorot_uniform(Random.GLOBAL_RNG, dims...)
glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...)
glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...)
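
# Illustrative usage sketch (not part of this diff), assuming the zero-argument
# `rng_from_array()` simply falls back to the default global RNG so that behaviour
# matches the previous `Random.GLOBAL_RNG` default:
using Flux, Random
W  = Flux.glorot_uniform(3, 4)                       # implicit default RNG
Wr = Flux.glorot_uniform(MersenneTwister(0), 3, 4)   # explicit RNG, reproducible
init = Flux.glorot_uniform(MersenneTwister(0))       # curried form, e.g. Dense(3, 4; init = init)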

"""
@@ -114,7 +114,7 @@ julia> Flux.glorot_normal(3, 2)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
glorot_normal(dims...) = glorot_normal(Random.GLOBAL_RNG, dims...)
glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...)
glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...)

"""
@@ -151,7 +151,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
end

kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

"""
@@ -188,9 +188,58 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
return randn(rng, Float32, dims...) .* std
end

kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

"""
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)
Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution.
The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(dims...))`.
The values are generated by sampling a Uniform(0, 1) (`rand()`) and then
applying the inverse CDF of the truncated normal distribution
(see the references for more info).
This method works best when `lo ≤ mean ≤ hi`.
# Examples
```jldoctest
julia> using Statistics

julia> Flux.truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"

julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
(-2.0f0, 2.0f0)

julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
1.0f0
```
# References
[1] Burkardt, John. "The Truncated Normal Distribution"
[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
Department of Scientific Computing website.
"""
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
norm_cdf(x) = 0.5 * (1 + erf(x/√2))
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
end
l = norm_cdf((lo - mean) / std)
u = norm_cdf((hi - mean) / std)
xs = rand(rng, Float32, dims...)
broadcast!(xs, xs) do x
x = x * 2(u - l) + (2l - 1)
x = erfinv(x)
x = clamp(x * std * √2 + mean, lo, hi)
end
return xs
end

truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
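
# Illustrative usage sketch (not part of this diff): the new initialiser can be handed
# to a layer through the standard `init` keyword, either directly or with its keyword
# arguments fixed by a small wrapper (`my_init` below is purely hypothetical).
using Flux
layer = Dense(784, 256, relu; init = Flux.truncated_normal)
my_init(dims...) = Flux.truncated_normal(dims...; mean = 0, std = 0.02, lo = -0.04, hi = 0.04)
layer2 = Dense(256, 10; init = my_init)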

"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
@@ -232,6 +281,7 @@ true
* sparse initialization: [`sparse_init`](@ref Flux.sparse_init)
# References
[1] Saxe, McClelland, Ganguli. "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks", ICLR 2014, https://arxiv.org/abs/1312.6120
"""
@@ -254,7 +304,7 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
end

orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

"""
@@ -298,7 +348,7 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
return mapslices(shuffle, sparse_array, dims=1)
end

sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

"""
@@ -382,7 +432,7 @@ function identity_init(dims...; gain=1, shift=0)
end

identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...)
identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...)
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

ones32(dims...) = Base.ones(Float32, dims...)
@@ -437,8 +487,8 @@ end
Flatten a model's parameters into a single weight vector.
julia> m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
julia> m = Chain(Dense(10, 5, std), Dense(5, 2), softmax)
Chain(Dense(10, 5, std), Dense(5, 2), softmax)
julia> θ, re = destructure(m);
21 changes: 19 additions & 2 deletions test/utils.jl
@@ -1,9 +1,10 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
kaiming_normal, kaiming_uniform, orthogonal,
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
sparse_init, stack, unstack, Zeros, batch, unbatch,
unsqueeze
unsqueeze, params
using StatsBase: var, std
using Statistics, LinearAlgebra
using Random
using Test

@@ -146,6 +147,22 @@ end
end
end

@testset "truncated_normal" begin
size = (100, 100, 100)
for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-1)
@test isapprox(minimum(v), lo; atol = 1f-1)
@test isapprox(maximum(v), hi; atol = 1f-1)
@test eltype(v) == Float32
end
for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
@test isapprox(std(v), σ; atol = 1f-1)
@test eltype(v) == Float32
end
end

@testset "partial_application" begin
big = 1e9

