From 5c358d0d8caf377d8b376594241f498a5331e6a5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 2 Feb 2024 12:51:59 -0500
Subject: [PATCH] allow scaling of init functions

---
Reviewer notes (this free-text area is the conventional place for patch commentary):
- The FixRNG call method now forwards the keyword arguments stored at
  construction (e.g. `gain`); call-time keywords take precedence. Previously
  `init.kwargs` was stored but never used, so `glorot_uniform(rng; gain=2)`
  lost its `gain`.
- The convenience constructors convert `scale` with `Float32(scale)`, so any
  `Real` accepted by their signature can be stored in the `Float32` field.
- Docstring typo "with makes" -> "which makes" (two places).
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Hunk line counts were re-checked, but context-line
  whitespace is a best guess (TODO confirm against the tree; if it does not
  apply cleanly try `git apply --ignore-whitespace`). The post-image blob
  hash on the src/utils.jl `index` line predates the edits above.

 docs/src/utilities.md | 16 ++++++++++
 src/utils.jl          | 71 ++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/docs/src/utilities.md b/docs/src/utilities.md
index bc1200124e..04859a8c0b 100644
--- a/docs/src/utilities.md
+++ b/docs/src/utilities.md
@@ -27,6 +27,22 @@ julia> Dense(4 => 5, tanh; init=Flux.randn32(MersenneTwister(1)))
 Dense(4 => 5, tanh)  # 25 parameters
 ```
 
+All of the initialisation functions may be multiplied by a number
+to scale their output:
+
+```jldoctest; setup = :(using Flux, Random)
+julia> lay = Dense(3 => 1, relu; init=42*Flux.ones32)
+Dense(3 => 1, relu)  # 4 parameters
+
+julia> lay.weight
+1×3 Matrix{Float32}:
+ 42.0  42.0  42.0
+
+julia> lay.bias
+1-element Vector{Float32}:
+ 0.0
+```
+
 ## Initialisation functions
 
 ```@docs
diff --git a/src/utils.jl b/src/utils.jl
index 082d9dcb1c..15a4ecbe5b 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -46,6 +46,57 @@ rng_from_array(::AbstractArray) = Random.default_rng()
 
 @non_differentiable rng_from_array(::Any)
 
+"""
+    FixRNG(init, [rng, scale; kw...])
+
+This is a bit like `Base.Fix1` in that `FixRNG(randn, rng)` makes a function,
+but also allows for scaling by a factor.
+
+It exists to allow modifying initialisation functions:
+```
+julia> 2 * randn32
+Flux.FixRNG(randn32, 2.0)
+
+julia> Dense(3 => 1, init=pi*ones32)
+Dense(3 => 1)  # 4 parameters
+
+julia> ans.weight
+1×3 Matrix{Float32}:
+ 3.14159  3.14159  3.14159
+```
+
+The struct itself is not part of Flux's API, so using it directly is not recommended.
+"""
+struct FixRNG{F<:Function, R<:Tuple, K<:NamedTuple} <: Function
+  fun::F
+  args::R
+  kwargs::K
+  scale::Float32
+end
+FixRNG(f::Function, scale::Real=1f0; kw...) = FixRNG(f, (), NamedTuple(kw), Float32(scale))
+FixRNG(f::Function, rng::AbstractRNG, scale::Real=1f0; kw...) = FixRNG(f, (rng,), NamedTuple(kw), Float32(scale))
+
+function (init::FixRNG)(args...; kw...)
+  raw = init.fun(init.args..., args...; init.kwargs..., kw...)  # stored keywords first, call-time ones override
+  if isone(init.scale)
+    return raw
+  elseif raw isa Array{<:AbstractFloat}
+    return lmul!(init.scale, raw)  # premature optimisation to save alloc!
+  else
+    return @. oftype(float(raw), init.scale * raw)
+  end
+end
+
+Base.:*(λ::Real, init::FixRNG) = FixRNG(init.fun, init.args, init.kwargs, Float32(λ * init.scale))
+
+function Base.show(io::IO, init::FixRNG)
+  print(io, "Flux.FixRNG(", init.fun)
+  isempty(init.args) || print(io, ", ", join(init.args, ", "))
+  isone(init.scale) || print(io, ", ", init.scale)
+  print(io, ")")
+end
+Base.show(io::IO, ::MIME"text/plain", init::FixRNG) = Base.show(io, init)  # needed because of <:Function
+
 """
     glorot_uniform([rng], size...; gain = 1) -> Array
     glorot_uniform([rng]; kw...) -> Function
@@ -87,7 +138,7 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
   (rand(rng, Float32, dims...) .- 0.5f0) .* scale
 end
 glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng(), dims...; kw...)
-glorot_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
+glorot_uniform(rng::AbstractRNG=default_rng(); init_kwargs...) = FixRNG(glorot_uniform, rng; init_kwargs...)
 
 ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
 
@@ -130,7 +181,7 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
   randn(rng, Float32, dims...) .* std
 end
 glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng(), dims...; kwargs...)
-glorot_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)
+glorot_normal(rng::AbstractRNG=default_rng(); init_kwargs...) = FixRNG(glorot_normal, rng; init_kwargs...)
 
 ChainRulesCore.@non_differentiable glorot_normal(::Any...)
 
@@ -455,6 +506,9 @@ ChainRulesCore.@non_differentiable identity_init(::Any...)
     ones32(size...) = ones(Float32, size...)
 
 Return an `Array{Float32}` of the given `size` filled with 1s.
+
+Multiplying by a number scales the output. Thus `init = 10 * ones32` is a function
+which makes an array with all values `10f0`.
 """
 ones32(dims...) = Base.ones(Float32, dims...)
 
@@ -473,17 +527,26 @@ When the size is not provided, `rand32(rng::AbstractRNG)` returns a function.
 """
 rand32(dims::Integer...) = Base.rand(Float32, dims...)
 rand32(rng::AbstractRNG, dims::Integer...) = Base.rand(rng, Float32, dims...)
-rand32(rng::AbstractRNG) = (dims...,) -> Base.rand(rng, Float32, dims...)
+rand32(rng::AbstractRNG) = FixRNG(rand32, rng)
 
 """
     randn32([rng], size...)
 
 Return an `Array{Float32}` of the given `size`, filled like `randn`.
 When the size is not provided, `randn32(rng::AbstractRNG)` returns a function.
+
+Multiplying by a number scales the output. Thus `init = 10 * randn32` is a function
+which makes an array of mean zero and standard deviation 10.
 """
 randn32(dims::Integer...) = Base.randn(Float32, dims...)
 randn32(rng::AbstractRNG, dims::Integer...) = Base.randn(rng, Float32, dims...)
-randn32(rng::AbstractRNG) = (dims...,) -> Base.randn(rng, Float32, dims...)
+randn32(rng::AbstractRNG) = FixRNG(randn32, rng)
+
+for fun in [ones32, zeros32, rand32, randn32]
+  @eval Base.:*(λ::Real, fun::typeof($fun)) = λ * FixRNG($fun)
+  @eval Base.:*(fun::typeof($fun), λ::Real) = λ * FixRNG($fun)
+  @eval Base.:/(fun::typeof($fun), λ::Real) = (1/λ) * FixRNG($fun)
+end
 
 """
     create_bias(weights, bias, size...)