Skip to content

Commit

Permalink
rng_from_array() -> default_rng_value()
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Aug 6, 2022
1 parent 9ab34a7 commit 91013ff
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 31 deletions.
2 changes: 2 additions & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ Flux.ones32
Flux.zeros32
Flux.rand32
Flux.randn32
Flux.rng_from_array
Flux.default_rng_value
```

## Changing the type of model parameters
Expand Down
12 changes: 6 additions & 6 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

"""
Dropout(p; dims=:, rng = _rng_from_array())
Dropout(p; dims=:, rng = default_rng_value())
Dropout layer.
Expand Down Expand Up @@ -96,9 +96,9 @@ mutable struct Dropout{F,D,R<:AbstractRNG}
active::Union{Bool, Nothing}
rng::R
end
Dropout(p, dims, active) = Dropout(p, dims, active, _rng_from_array())
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())

function Dropout(p; dims=:, rng = _rng_from_array())
function Dropout(p; dims=:, rng = default_rng_value())
@assert 0 p 1
Dropout(p, dims, nothing, rng)
end
Expand All @@ -121,7 +121,7 @@ function Base.show(io::IO, d::Dropout)
end

"""
AlphaDropout(p; rng = _rng_from_array())
AlphaDropout(p; rng = default_rng_value())
A dropout layer. Used in
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
Expand Down Expand Up @@ -155,8 +155,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
new{typeof(p), typeof(rng)}(p, active, rng)
end
end
AlphaDropout(p, active) = AlphaDropout(p, active, _rng_from_array())
AlphaDropout(p; rng = _rng_from_array()) = AlphaDropout(p, nothing, rng)
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout
trainable(a::AlphaDropout) = (;)
Expand Down
58 changes: 33 additions & 25 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,24 @@ The current defaults are:
- Julia version is < 1.7: `Random.GLOBAL_RNG`
- Julia version is >= 1.7: `Random.default_rng()`
"""
_rng_from_array(::AbstractArray) = _rng_from_array()
_rng_from_array(::CuArray) = CUDA.default_rng()
rng_from_array(::AbstractArray) = default_rng_value()
rng_from_array(::CuArray) = CUDA.default_rng()

if VERSION >= v"1.7"
_rng_from_array() = Random.default_rng()
@doc """
default_rng_value()
Create an instance of the default RNG depending on Julia's version.
- Julia version is < 1.7: `Random.GLOBAL_RNG`
- Julia version is >= 1.7: `Random.default_rng()`
"""
default_rng_value() = Random.default_rng()
else
_rng_from_array() = Random.GLOBAL_RNG
default_rng_value() = Random.GLOBAL_RNG
end

"""
glorot_uniform([rng=GLOBAL_RNG], size...; gain = 1) -> Array
glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array
glorot_uniform([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform
Expand Down Expand Up @@ -91,13 +99,13 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
scale = Float32(gain) * sqrt(24.0f0 / sum(nfan(dims...)))
(rand(rng, Float32, dims...) .- 0.5f0) .* scale
end
glorot_uniform(dims::Integer...; kw...) = glorot_uniform(_rng_from_array(), dims...; kw...)
glorot_uniform(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng_value(), dims...; kw...)
glorot_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable glorot_uniform(::Any...)

"""
glorot_normal([rng=GLOBAL_RNG], size...; gain = 1) -> Array
glorot_normal([rng = default_rng_value(), size...; gain = 1) -> Array
glorot_normal([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
Expand Down Expand Up @@ -134,13 +142,13 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
std = Float32(gain) * sqrt(2.0f0 / sum(nfan(dims...)))
randn(rng, Float32, dims...) .* std
end
glorot_normal(dims::Integer...; kwargs...) = glorot_normal(_rng_from_array(), dims...; kwargs...)
glorot_normal(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)
glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng_value(), dims...; kwargs...)
glorot_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable glorot_normal(::Any...)

"""
kaiming_uniform([rng=GLOBAL_RNG], size...; gain = √2) -> Array
kaiming_uniform([rng = default_rng_value()], size...; gain = √2) -> Array
kaiming_uniform([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform distribution
Expand Down Expand Up @@ -169,13 +177,13 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = √2)
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
end

kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(_rng_from_array(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(default_rng_value(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)

"""
kaiming_normal([rng=GLOBAL_RNG], size...; gain = √2) -> Array
kaiming_normal([rng = default_rng_value()], size...; gain = √2) -> Array
kaiming_normal([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal
Expand Down Expand Up @@ -206,13 +214,13 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real = √2f0)
return randn(rng, Float32, dims...) .* std
end

kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(_rng_from_array(), dims...; kwargs...)
kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(default_rng_value(), dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable kaiming_normal(::Any...)

"""
truncated_normal([rng=GLOBAL_RNG], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
Expand Down Expand Up @@ -252,13 +260,13 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1,
return xs
end

truncated_normal(dims::Integer...; kwargs...) = truncated_normal(_rng_from_array(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
truncated_normal(dims::Integer...; kwargs...) = truncated_normal(default_rng_value(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable truncated_normal(::Any...)

"""
orthogonal([rng=GLOBAL_RNG], size...; gain = 1) -> Array
orthogonal([rng = default_rng_value()], size...; gain = 1) -> Array
orthogonal([rng]; kw...) -> Function
Return an `Array{Float32}` of the given `size` which is a (semi) orthogonal matrix, as described in [1].
Expand Down Expand Up @@ -313,13 +321,13 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
end

orthogonal(dims::Integer...; kwargs...) = orthogonal(_rng_from_array(), dims...; kwargs...)
orthogonal(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
orthogonal(dims::Integer...; kwargs...) = orthogonal(default_rng_value(), dims...; kwargs...)
orthogonal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable orthogonal(::Any...)

"""
sparse_init([rng=GLOBAL_RNG], rows, cols; sparsity, std = 0.01) -> Array
sparse_init([rng = default_rng_value()], rows, cols; sparsity, std = 0.01) -> Array
sparse_init([rng]; kw...) -> Function
Return a `Matrix{Float32}` of size `rows, cols` where each column contains a fixed fraction of
Expand Down Expand Up @@ -361,8 +369,8 @@ function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
return mapslices(shuffle, sparse_array, dims=1)
end

sparse_init(dims::Integer...; kwargs...) = sparse_init(_rng_from_array(), dims...; kwargs...)
sparse_init(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
sparse_init(dims::Integer...; kwargs...) = sparse_init(default_rng_value(), dims...; kwargs...)
sparse_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable sparse_init(::Any...)

Expand Down Expand Up @@ -452,7 +460,7 @@ end

# For consistency, it accepts an RNG, but ignores it:
identity_init(::AbstractRNG, dims::Integer...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(rng::AbstractRNG=_rng_from_array(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)
identity_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable identity_init(::Any...)

Expand Down

0 comments on commit 91013ff

Please sign in to comment.