Support for LeCun normal weight initialization
RohitRathore1 committed Aug 11, 2023
1 parent 1348828 commit 56f36c2
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/utils.jl
@@ -252,6 +252,48 @@ truncated_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwa

ChainRulesCore.@non_differentiable truncated_normal(::Any...)

"""
lecun_normal([rng], size...) -> Array
lecun_normal([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a truncated normal
distribution centered on 0 with stddev `sqrt(1 / fan_in)`, where `fan_in` is the number of input units
in the weight tensor.
# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> using Statistics
julia> round(std(Flux.lecun_normal(10, 1000)), digits=3)
0.032f0
julia> round(std(Flux.lecun_normal(1000, 10)), digits=3)
0.317f0
julia> round(std(Flux.lecun_normal(1000, 1000)), digits=3)
0.032f0
julia> Dense(10 => 1000, selu; init = Flux.lecun_normal())
Dense(10 => 1000, selu) # 11_000 parameters
julia> round(std(ans.weight), sigdigits=3)
0.319f0
```
# References
[1] LeCun, Yann, et al. "Efficient BackProp." Neural Networks: Tricks of the Trade. Springer, Berlin, Heidelberg, 2012. 9-48.
"""
function lecun_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
  std = Float32(gain) * sqrt(1.0f0 / first(nfan(dims...))) # calculates the standard deviation based on the `fan_in` value
  return truncated_normal(rng, dims...; mean=0, std=std)
end

lecun_normal(dims::Integer...; kwargs...) = lecun_normal(default_rng(), dims...; kwargs...)
lecun_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> lecun_normal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable lecun_normal(::Any...)

"""
orthogonal([rng], size...; gain = 1) -> Array
orthogonal([rng]; kw...) -> Function
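For reference, a minimal usage sketch of the new initializer (assuming this commit is available so that `Flux.lecun_normal` exists; the pairing with `selu` mirrors the docstring's own `Dense` example and is illustrative, not something the commit itself prescribes):

```julia
using Flux, Statistics

# Build a layer with the LeCun-normal initializer. Calling `Flux.lecun_normal()`
# with no dimensions returns a closure, so it can be passed as `init`,
# mirroring the existing `truncated_normal` API shown in the diff.
layer = Dense(256 => 128, selu; init = Flux.lecun_normal())

# The weight matrix is 128×256, so fan_in = 256 and the expected
# standard deviation is sqrt(1 / 256) = 0.0625.
println(std(layer.weight))        # ≈ 0.0625

# Calling the initializer with explicit dimensions returns an Array{Float32}.
W = Flux.lecun_normal(128, 256)
println(size(W), " ", eltype(W))  # (128, 256) Float32
```

Pairing the init with `selu` follows the docstring example: LeCun-normal scaling keeps the pre-activation variance close to the input variance, which is the condition `selu`'s self-normalizing behaviour assumes.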
