Skip to content

Commit

Permalink
allow scaling of init functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 2, 2024
1 parent 674ed1a commit 5c358d0
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 4 deletions.
16 changes: 16 additions & 0 deletions docs/src/utilities.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
Dense(4 => 5, tanh) # 25 parameters
```

All of the initialisation functions may be multiplied by a number
to scale their output:

```jldoctest; setup = :(using Flux, Random)
julia> lay = Dense(3 => 1, relu; init=42*Flux.ones32)
Dense(3 => 1, relu) # 4 parameters

julia> lay.weight
1×3 Matrix{Float32}:
 42.0  42.0  42.0

julia> lay.bias
1-element Vector{Float32}:
 0.0
```

## Initialisation functions

```@docs
Expand Down
71 changes: 67 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,57 @@ rng_from_array(::AbstractArray) = Random.default_rng()
@non_differentiable rng_from_array(::Any)  # exclude `rng_from_array` from differentiation, whatever its argument


"""
    FixRNG(init, [rng, scale; kw...])

This is a bit like `Base.Fix1` in that `FixRNG(randn, rng)` makes a function,
but also allows for scaling by a factor.

Keyword arguments given at construction are stored and forwarded to `init` on every
call; keywords given at call time take precedence over the stored ones.
The `scale` may be any `Real`; it is stored as a `Float32`.

It exists to allow modifying initialisation functions:
```
julia> 2 * randn32
Flux.FixRNG(randn32, 2.0)

julia> Dense(3 => 1, init=pi*ones32)
Dense(3 => 1) # 4 parameters

julia> ans.weight
1×3 Matrix{Float32}:
 3.14159  3.14159  3.14159
```

The struct itself is not part of Flux's API, so using it directly is not recommended.
"""
struct FixRNG{F<:Function, R<:Tuple, K<:NamedTuple} <: Function
    fun::F
    args::R
    kwargs::K
    scale::Float32
end

# The default constructor of a parametric struct does not convert its arguments,
# so the scale is converted here to accept any `Real` (e.g. `2` or `2.0`).
function FixRNG(f::Function, scale::Real=1f0; kw...)
    return FixRNG(f, (), NamedTuple(kw), Float32(scale))
end
function FixRNG(f::Function, rng::AbstractRNG, scale::Real=1f0; kw...)
    return FixRNG(f, (rng,), NamedTuple(kw), Float32(scale))
end

function (init::FixRNG)(args...; kw...)
    # Stored keywords go first so that keywords supplied at call time override them.
    raw = init.fun(init.args..., args...; init.kwargs..., kw...)
    if isone(init.scale)
        return raw
    elseif raw isa Array{<:AbstractFloat}
        # Freshly created floating-point array: scale it in place to save an allocation.
        raw .*= init.scale
        return raw
    else
        # Anything else (e.g. integer arrays): scale into a new floating-point result,
        # keeping the element type that `float(raw)` would have.
        return @. oftype(float(raw), init.scale * raw)
    end
end

# Scaling an existing `FixRNG` folds the factor into its stored `scale`; the wrapped
# function, fixed arguments and stored keywords are carried over unchanged.
function Base.:*(λ::Real, init::FixRNG)
    return FixRNG(init.fun, init.args, init.kwargs, Float32(λ * init.scale))
end
# Right-multiplication and division, matching the operators defined for the plain
# initialisers such as `ones32` and `randn32`.
Base.:*(init::FixRNG, λ::Real) = λ * init
Base.:/(init::FixRNG, λ::Real) = (1 / λ) * init

# Compact printing: the wrapped function, then any fixed positional arguments,
# then the scale factor unless it is exactly one.
function Base.show(io::IO, init::FixRNG)
    print(io, "Flux.FixRNG(", init.fun)
    if !isempty(init.args)
        fixed = join(init.args, ", ")
        print(io, ", ", fixed)
    end
    if !isone(init.scale)
        print(io, ", ", init.scale)
    end
    return print(io, ")")
end

# needed because of <:Function
function Base.show(io::IO, ::MIME"text/plain", init::FixRNG)
    return Base.show(io, init)
end

"""
glorot_uniform([rng], size...; gain = 1) -> Array
glorot_uniform([rng]; kw...) -> Function
Expand Down Expand Up @@ -87,7 +138,7 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
(rand(rng, Float32, dims...) .- 0.5f0) .* scale
end
# No RNG given: forward to the method that takes one, passing `default_rng()`.
glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng(), dims...; kw...)
# No sizes given: return a callable instead of an array.
# NOTE(review): the next two definitions have the same signature, so the later one is the
# one in effect — confirm that the closure version is meant to be removed.
# Closure version: captures `rng` and `init_kwargs`; keywords passed at call time are splatted last.
glorot_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
# `FixRNG` version: hands the function, the RNG and the keywords to `FixRNG`.
glorot_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = FixRNG(glorot_uniform, rng; init_kwargs...)

# Mark `glorot_uniform` as non-differentiable for any arguments.
ChainRulesCore.@non_differentiable glorot_uniform(::Any...)

Expand Down Expand Up @@ -130,7 +181,7 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
randn(rng, Float32, dims...) .* std
end
# No RNG given: forward to the method that takes one, passing `default_rng()`.
glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng(), dims...; kwargs...)
# No sizes given: return a callable instead of an array.
# NOTE(review): the next two definitions have the same signature, so the later one is the
# one in effect — confirm that the closure version is meant to be removed.
# Closure version: captures `rng` and `init_kwargs`; keywords passed at call time are splatted last.
glorot_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)
# `FixRNG` version: hands the function, the RNG and the keywords to `FixRNG`.
glorot_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = FixRNG(glorot_normal, rng; init_kwargs...)

# Mark `glorot_normal` as non-differentiable for any arguments.
ChainRulesCore.@non_differentiable glorot_normal(::Any...)

Expand Down Expand Up @@ -455,6 +506,9 @@ ChainRulesCore.@non_differentiable identity_init(::Any...)
ones32(size...) = ones(Float32, size...)
Return an `Array{Float32}` of the given `size` filled with 1s.
Multiplying by a number scales the output. Thus `init = 10 * ones32` is a function
which makes an array with all values `10f0`.
"""
function ones32(dims...)
    # Same as `Base.ones`, with the element type fixed to Float32.
    return Base.ones(Float32, dims...)
end

Expand All @@ -473,17 +527,26 @@ When the size is not provided, `rand32(rng::AbstractRNG)` returns a function.
"""
rand32(dims::Integer...) = Base.rand(Float32, dims...)  # no RNG argument is passed on to `Base.rand`
# Explicit RNG: passed to `Base.rand` ahead of the element type and sizes.
rand32(rng::AbstractRNG, dims::Integer...) = Base.rand(rng, Float32, dims...)
# RNG only, no sizes: return a callable that takes the sizes later.
# NOTE(review): the next two definitions have the same signature, so the later one is the
# one in effect — confirm that the closure version is meant to be removed.
rand32(rng::AbstractRNG) = (dims...,) -> Base.rand(rng, Float32, dims...)
rand32(rng::AbstractRNG) = FixRNG(rand32, rng)

"""
randn32([rng], size...)
Return an `Array{Float32}` of the given `size`, filled like `randn`.
When the size is not provided, `randn32(rng::AbstractRNG)` returns a function.
Multiplying by a number scales the output. Thus `init = 10 * randn32` is a function
which makes an array with mean zero and standard deviation 10.
"""
randn32(dims::Integer...) = Base.randn(Float32, dims...)  # no RNG argument is passed on to `Base.randn`
# Explicit RNG: passed to `Base.randn` ahead of the element type and sizes.
randn32(rng::AbstractRNG, dims::Integer...) = Base.randn(rng, Float32, dims...)
# RNG only, no sizes: return a callable that takes the sizes later.
# NOTE(review): the next two definitions have the same signature, so the later one is the
# one in effect — confirm that the closure version is meant to be removed.
randn32(rng::AbstractRNG) = (dims...,) -> Base.randn(rng, Float32, dims...)
randn32(rng::AbstractRNG) = FixRNG(randn32, rng)

# Let the plain initialisers be scaled directly, e.g. `2 * randn32` or `ones32 / 3`:
# each operator wraps the function in a `FixRNG` and multiplies in the factor.
for fun in [ones32, zeros32, rand32, randn32]
    @eval Base.:*(λ::Real, fun::typeof($fun)) = λ * FixRNG($fun)
    @eval Base.:*(fun::typeof($fun), λ::Real) = λ * FixRNG($fun)
    # Division is multiplication by the reciprocal.
    @eval Base.:/(fun::typeof($fun), λ::Real) = (1/λ) * FixRNG($fun)
end

"""
create_bias(weights, bias, size...)
Expand Down

0 comments on commit 5c358d0

Please sign in to comment.