diff --git a/docs/src/utilities.md b/docs/src/utilities.md
index 6e6226a45f..28f6bc4a18 100644
--- a/docs/src/utilities.md
+++ b/docs/src/utilities.md
@@ -42,7 +42,11 @@ Flux.orthogonal
 Flux.sparse_init
 Flux.identity_init
 Flux.ones32
+Flux.zeros32
 Flux.rand32
+Flux.randn32
+Flux.rng_from_array
+Flux.default_rng_value
 ```

 ## Changing the type of model parameters
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 42813cb5f7..51c5fda9b1 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -155,7 +155,7 @@ struct Dense{F, M<:AbstractMatrix, B}
   bias::B
   σ::F
   function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
-    b = create_bias(W, bias, size(W,1))
+    b = _create_bias(W, bias, size(W,1))
     new{F,M,typeof(b)}(W, b, σ)
   end
 end
@@ -228,7 +228,7 @@ struct Scale{F, A<:AbstractArray, B}
   bias::B
   σ::F
   function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
-    b = create_bias(scale, bias, size(scale)...)
+    b = _create_bias(scale, bias, size(scale)...)
     new{F, A, typeof(b)}(scale, b, σ)
   end
 end
@@ -403,7 +403,7 @@ struct Bilinear{F,A,B}
   σ::F
   function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
     ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
-    b = create_bias(W, bias, size(W,1))
+    b = _create_bias(W, bias, size(W,1))
     new{F,A,typeof(b)}(W, b, σ)
   end
 end
diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 003395c15d..36aa5c8430 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -156,7 +156,7 @@ function Conv(w::AbstractArray{T,N}, b = true, σ = identity;
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
-  bias = create_bias(w, b, size(w, N))
+  bias = _create_bias(w, b, size(w, N))
   return Conv(σ, w, bias, stride, pad, dilation, groups)
 end

@@ -293,7 +293,7 @@ function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
-  b = create_bias(w, bias, size(w, N-1) * groups)
+  b = _create_bias(w, bias, size(w, N-1) * groups)
   return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
 end

@@ -441,7 +441,7 @@ function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
   stride = expand(Val(N-2), stride)
   dilation = expand(Val(N-2), dilation)
   pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride)
-  b = create_bias(w, bias, size(w, N))
+  b = _create_bias(w, bias, size(w, N))
   return CrossCor(σ, w, b, stride, pad, dilation)
 end

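The constructor hunks above only rename `create_bias` to `_create_bias`; the accepted `bias` values are unchanged. A minimal sketch of the three forms, using `Dense` (the example itself is not part of the patch):

```julia
using Flux

# `bias = true` allocates a zero vector matching the weight array,
# `bias = false` disables the bias entirely, and an explicit array is
# used as-is (its length must equal size(W, 1)).
W = Flux.glorot_uniform(3, 2)

layer_default = Dense(W)                  # bias = true → zero vector of length 3
layer_nobias  = Dense(W, false)           # no bias term at all
layer_custom  = Dense(W, fill(0.1f0, 3))  # user-supplied bias of size (3,)

x = rand(Float32, 2, 5)
size(layer_default(x))                    # (3, 5)
```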
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index f1f6c22033..446575f355 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -51,7 +51,7 @@ end
 ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

 """
-    Dropout(p; dims=:, rng = rng_from_array())
+    Dropout(p; dims=:, rng = default_rng_value())

 Dropout layer.

@@ -96,9 +96,9 @@ mutable struct Dropout{F,D,R<:AbstractRNG}
   active::Union{Bool, Nothing}
   rng::R
 end
-Dropout(p, dims, active) = Dropout(p, dims, active, rng_from_array())
+Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())

-function Dropout(p; dims=:, rng = rng_from_array())
+function Dropout(p; dims=:, rng = default_rng_value())
   @assert 0 ≤ p ≤ 1
   Dropout(p, dims, nothing, rng)
 end
@@ -121,7 +121,7 @@ function Base.show(io::IO, d::Dropout)
 end

 """
-    AlphaDropout(p; rng = rng_from_array())
+    AlphaDropout(p; rng = default_rng_value())

 A dropout layer. Used in
 [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -155,8 +155,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
     new{typeof(p), typeof(rng)}(p, active, rng)
   end
 end
-AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
-AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)
+AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
+AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)

 @functor AlphaDropout
 trainable(a::AlphaDropout) = (;)
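The dropout layers keep their `rng` keyword after the default switches to `default_rng_value()`. A small sketch, not taken from the patch, of seeding that keyword for reproducible masks:

```julia
using Flux, Random

# Two layers seeded with identical RNGs produce identical dropout masks.
d1 = Dropout(0.5; rng = MersenneTwister(1))
d2 = Dropout(0.5; rng = MersenneTwister(1))

Flux.trainmode!(d1); Flux.trainmode!(d2)  # force the mask to be applied outside training

x = ones(Float32, 4, 3)
d1(x) == d2(x)                            # true
```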
+ +# Example +```jldoctest +julia> y_model = [1, 3, 3]; # data should only take integral values + +julia> Flux.poisson_loss(y_model, 1:3) +0.5023128522198171 +``` """ function poisson_loss(ŷ, y; agg = mean) _check_sizes(ŷ, y) @@ -392,11 +413,32 @@ end """ hinge_loss(ŷ, y; agg = mean) -Return the [hinge_loss loss](https://en.wikipedia.org/wiki/Hinge_loss) given the +Return the [hinge_loss](https://en.wikipedia.org/wiki/Hinge_loss) given the prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as -`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`. + sum(max.(0, 1 .- ŷ .* y)) / size(y, 2) + +Usually used with classifiers like Support Vector Machines. See also: [`squared_hinge_loss`](@ref) + +# Example +```jldoctest +julia> y_true = [1, -1, 1, 1]; + +julia> y_pred = [0.1, 0.3, 1, 1.5]; + +julia> Flux.hinge_loss(y_pred, y_true) +0.55 + +julia> Flux.hinge_loss(y_pred[1], y_true[1]) != 0 # same sign but |ŷ| < 1 +true + +julia> Flux.hinge_loss(y_pred[end], y_true[end]) == 0 # same sign but |ŷ| >= 1 +true + +julia> Flux.hinge_loss(y_pred[2], y_true[2]) != 0 # opposite signs +true +``` """ function hinge_loss(ŷ, y; agg = mean) _check_sizes(ŷ, y) @@ -407,9 +449,31 @@ end squared_hinge_loss(ŷ, y) Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y` -(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`. +(containing 1 or -1); calculated as + + sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2) +Usually used with classifiers like Support Vector Machines. See also: [`hinge_loss`](@ref) + +# Example +```jldoctes +julia> y_true = [1, -1, 1, 1]; + +julia> y_pred = [0.1, 0.3, 1, 1.5]; + +julia> Flux.squared_hinge_loss(y_pred, y_true) +0.625 + +julia> Flux.squared_hinge_loss(y_pred[1], y_true[1]) != 0 +true + +julia> Flux.squared_hinge_loss(y_pred[end], y_true[end]) == 0 +true + +julia> Flux.squared_hinge_loss(y_pred[2], y_true[2]) != 0 +true +``` """ function squared_hinge_loss(ŷ, y; agg = mean) _check_sizes(ŷ, y) @@ -422,9 +486,20 @@ end Return a loss based on the dice coefficient. Used in the [V-Net](https://arxiv.org/abs/1606.04797) image segmentation architecture. -Similar to the F1_score. Calculated as: +The dice coefficient is similar to the F1_score. Loss calculated as: 1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth) + +# Example +```jldoctest +julia> y_pred = [1.1, 2.1, 3.1]; + +julia> Flux.dice_coeff_loss(y_pred, 1:3) +0.000992391663909964 + +julia> 1 - Flux.dice_coeff_loss(y_pred, 1:3) # ~ F1 score for image segmentation +0.99900760833609 +``` """ function dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0)) _check_sizes(ŷ, y) @@ -436,9 +511,11 @@ end Return the [Tversky loss](https://arxiv.org/abs/1706.05721). Used with imbalanced data to give more weight to false negatives. -Larger β weigh recall more than precision (by placing more emphasis on false negatives) +Larger β weigh recall more than precision (by placing more emphasis on false negatives). Calculated as: - 1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1) + + 1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1) + """ function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7)) _check_sizes(ŷ, y) @@ -456,6 +533,8 @@ The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) out For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref). 
@@ -456,6 +533,8 @@ The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) output).

 For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).

+See also: [`Losses.focal_loss`](@ref) for multi-class setting
+
 # Example
 ```jldoctest
 julia> y = [0  1  0
@@ -473,9 +552,6 @@ julia> ŷ = [0.268941  0.5  0.268941

 julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
 true
 ```
-
-See also: [`Losses.focal_loss`](@ref) for multi-class setting
-
 """
 function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
   _check_sizes(ŷ, y)
@@ -536,7 +612,17 @@ which can be useful for training Siamese Networks. It is given by

     agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)

 Specify `margin` to set the baseline for distance at which pairs are dissimilar.
-
+
+# Example
+```jldoctest
+julia> ŷ = [0.5, 1.5, 2.5];
+
+julia> Flux.siamese_contrastive_loss(ŷ, 1:3)
+-4.833333333333333
+
+julia> Flux.siamese_contrastive_loss(ŷ, 1:3, margin = 2)
+-4.0
+```
 """
 function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
   _check_sizes(ŷ, y)
diff --git a/src/utils.jl b/src/utils.jl
index 85dc8b711f..ad963d5ed6 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -43,16 +43,24 @@ The current defaults are:
 - Julia version is < 1.7: `Random.GLOBAL_RNG`
 - Julia version is >= 1.7: `Random.default_rng()`
 """
-rng_from_array(::AbstractArray) = rng_from_array()
+rng_from_array(::AbstractArray) = default_rng_value()
 rng_from_array(::CuArray) = CUDA.default_rng()
+
 if VERSION >= v"1.7"
-  rng_from_array() = Random.default_rng()
+  @doc """
+      default_rng_value()
+
+  Create an instance of the default RNG depending on Julia's version.
+  - Julia version is < 1.7: `Random.GLOBAL_RNG`
+  - Julia version is >= 1.7: `Random.default_rng()`
+  """
+  default_rng_value() = Random.default_rng()
 else
-  rng_from_array() = Random.GLOBAL_RNG
+  default_rng_value() = Random.GLOBAL_RNG
 end

 """
-    glorot_uniform([rng=GLOBAL_RNG], size...; gain = 1) -> Array
+    glorot_uniform([rng = default_rng_value()], size...; gain = 1) -> Array
     glorot_uniform([rng]; kw...) -> Function

 Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform
@@ -91,13 +99,13 @@ function glorot_uniform(rng::AbstractRNG, dims::Integer...; gain::Real=1)
   scale = Float32(gain) * sqrt(24.0f0 / sum(nfan(dims...)))
   (rand(rng, Float32, dims...) .- 0.5f0) .* scale
 end
-glorot_uniform(dims::Integer...; kw...) = glorot_uniform(rng_from_array(), dims...; kw...)
-glorot_uniform(rng::AbstractRNG=rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)
+glorot_uniform(dims::Integer...; kw...) = glorot_uniform(default_rng_value(), dims...; kw...)
+glorot_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_uniform(rng, dims...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable glorot_uniform(::Any...)

 """
-    glorot_normal([rng=GLOBAL_RNG], size...; gain = 1) -> Array
+    glorot_normal([rng = default_rng_value()], size...; gain = 1) -> Array
     glorot_normal([rng]; kw...) -> Function

 Return an `Array{Float32}` of the given `size` containing random numbers drawn from a normal
@@ -134,13 +142,13 @@ function glorot_normal(rng::AbstractRNG, dims::Integer...; gain::Real=1)
   std = Float32(gain) * sqrt(2.0f0 / sum(nfan(dims...)))
   randn(rng, Float32, dims...) .* std
 end
-glorot_normal(dims::Integer...; kwargs...) = glorot_normal(rng_from_array(), dims...; kwargs...)
-glorot_normal(rng::AbstractRNG=rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)
+glorot_normal(dims::Integer...; kwargs...) = glorot_normal(default_rng_value(), dims...; kwargs...)
+glorot_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> glorot_normal(rng, dims...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable glorot_normal(::Any...)

 """
-    kaiming_uniform([rng=GLOBAL_RNG], size...; gain = √2) -> Array
+    kaiming_uniform([rng = default_rng_value()], size...; gain = √2) -> Array
     kaiming_uniform([rng]; kw...) -> Function

 Return an `Array{Float32}` of the given `size` containing random numbers drawn from a uniform distribution
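The initializer hunks keep both call styles: a direct `init(dims...)` call and the partial application that takes an RNG first. A sketch of the second style; the `init` keyword on `Dense` is assumed from Flux's layer constructors, not from this patch:

```julia
using Flux, Random

# Direct call: draws from the default RNG (now default_rng_value()).
W = Flux.glorot_uniform(64, 32)

# Partial application with an explicit RNG gives a reusable, reproducible initializer.
init = Flux.glorot_normal(MersenneTwister(0); gain = 1)
layer = Dense(32 => 64; init)   # `init` keyword assumed from Dense's documented constructor
```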
@@ -169,13 +177,13 @@ function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain::Real = √2)
   return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
 end

-kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
-kaiming_uniform(rng::AbstractRNG=rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
+kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(default_rng_value(), dims...; kwargs...)
+kaiming_uniform(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)

 """
-    kaiming_normal([rng=GLOBAL_RNG], size...; gain = √2) -> Array
+    kaiming_normal([rng = default_rng_value()], size...; gain = √2) -> Array
     kaiming_normal([rng]; kw...) -> Function

 Return an `Array{Float32}` of the given `size` containing random numbers taken from a normal
@@ -206,13 +214,13 @@ function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain::Real = √2f0)
   return randn(rng, Float32, dims...) .* std
 end

-kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
+kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(default_rng_value(), dims...; kwargs...)
 kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable kaiming_normal(::Any...)

 """
-    truncated_normal([rng=GLOBAL_RNG], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
+    truncated_normal([rng = default_rng_value()], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
     truncated_normal([rng]; kw...) -> Function

 Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
@@ -252,13 +260,13 @@ function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1,
   return xs
 end

-truncated_normal(dims::Integer...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
-truncated_normal(rng::AbstractRNG=rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
+truncated_normal(dims::Integer...; kwargs...) = truncated_normal(default_rng_value(), dims...; kwargs...)
+truncated_normal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable truncated_normal(::Any...)

 """
-    orthogonal([rng=GLOBAL_RNG], size...; gain = 1) -> Array
+    orthogonal([rng = default_rng_value()], size...; gain = 1) -> Array
     orthogonal([rng]; kw...) -> Function

 Return an `Array{Float32}` of the given `size` which is a (semi) orthogonal matrix, as described in [1].
@@ -313,13 +321,13 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
   return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
 end

-orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
-orthogonal(rng::AbstractRNG=rng_from_array(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
+orthogonal(dims::Integer...; kwargs...) = orthogonal(default_rng_value(), dims...; kwargs...)
+orthogonal(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable orthogonal(::Any...)

 """
-    sparse_init([rng=GLOBAL_RNG], rows, cols; sparsity, std = 0.01) -> Array
+    sparse_init([rng = default_rng_value()], rows, cols; sparsity, std = 0.01) -> Array
     sparse_init([rng]; kw...) -> Function

 Return a `Matrix{Float32}` of size `rows, cols` where each column contains a fixed fraction of
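A short check of the properties these initializers document, not taken from the patch (the square `orthogonal` output and the default `lo`/`hi` clipping of `truncated_normal` are as described in the docstrings above):

```julia
using Flux, LinearAlgebra

# A square `orthogonal` draw should satisfy Q'Q ≈ I.
Q = Flux.orthogonal(4, 4)
Q' * Q ≈ I                 # true, up to Float32 round-off

# With the default lo = -2, hi = 2, truncated_normal samples stay inside those bounds.
W = Flux.truncated_normal(3, 2)
all(-2 .<= W .<= 2)        # true
```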
@@ -361,8 +369,8 @@ function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
   return mapslices(shuffle, sparse_array, dims=1)
 end

-sparse_init(dims::Integer...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
-sparse_init(rng::AbstractRNG=rng_from_array(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
+sparse_init(dims::Integer...; kwargs...) = sparse_init(default_rng_value(), dims...; kwargs...)
+sparse_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable sparse_init(::Any...)

@@ -452,7 +460,7 @@ end

 # For consistency, it accepts an RNG, but ignores it:
 identity_init(::AbstractRNG, dims::Integer...; kwargs...) = identity_init(dims...; kwargs...)
-identity_init(rng::AbstractRNG=rng_from_array(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)
+identity_init(rng::AbstractRNG=default_rng_value(); init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

 ChainRulesCore.@non_differentiable identity_init(::Any...)

@@ -461,33 +469,40 @@ zeros32(dims::Integer...) = Base.zeros(Float32, dims...)

 """
     ones32(size...) = ones(Float32, size...)
-    zeros32(size...) = zeros(Float32, size...)

-Return an `Array{Float32}` of the given `size`.
+Return an `Array{Float32}` of the given `size` filled with 1s.
 """
 ones32(dims...) = Base.ones(Float32, dims...)

-@doc @doc(ones32)
+"""
+    zeros32(size...) = zeros(Float32, size...)
+
+Return an `Array{Float32}` of the given `size` filled with 0s.
+"""
 zeros32(dims...) = Base.zeros(Float32, dims...)

 """
     rand32([rng], size...)
-    randn32([rng], size...)

-Return an `Array{Float32}` of the given `size`, filled like `rand` or `randn`.
+Return an `Array{Float32}` of the given `size`, filled like `rand`.
 When the size is not provided, `rand32(rng::AbstractRNG)` returns a function.
 """
 rand32(dims::Integer...) = Base.rand(Float32, dims...)
 rand32(rng::AbstractRNG, dims::Integer...) = Base.rand(rng, Float32, dims...)
 rand32(rng::AbstractRNG) = (dims...,) -> Base.rand(rng, Float32, dims...)

-@doc @doc(rand32)
+"""
+    randn32([rng], size...)
+
+Return an `Array{Float32}` of the given `size`, filled like `randn`.
+When the size is not provided, `randn32(rng::AbstractRNG)` returns a function.
+"""
 randn32(dims::Integer...) = Base.randn(Float32, dims...)
 randn32(rng::AbstractRNG, dims::Integer...) = Base.randn(rng, Float32, dims...)
 randn32(rng::AbstractRNG) = (dims...,) -> Base.randn(rng, Float32, dims...)

 """
-    create_bias(weights, bias, size...)
+    _create_bias(weights, bias, size...)

 Return a bias parameter for a layer, based on the value given
 to the constructor's keyword `bias=bias`.
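The now separately documented helpers behave exactly as the one-line definitions in the hunk; a brief usage sketch (not part of the patch):

```julia
using Flux, Random

Flux.zeros32(2, 3) == zeros(Float32, 2, 3)   # true
Flux.ones32(2)     == ones(Float32, 2)       # true

# rand32/randn32 take an optional RNG, or return a function when given only an RNG:
x = Flux.randn32(MersenneTwister(1), 4)      # Vector{Float32} of length 4
make = Flux.rand32(MersenneTwister(1))
size(make(2, 2))                             # (2, 2)
```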
""" - create_bias(weights, bias, size...) + _create_bias(weights, bias, size...) Return a bias parameter for a layer, based on the value given to the constructor's keyword `bias=bias`. @@ -497,10 +512,10 @@ to the constructor's keyword `bias=bias`. * `bias::AbstractArray` uses the array provided, provided it has the correct size. It does not at present correct the `eltype` to match that of `weights`. """ -function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) +function _create_bias(weights::AbstractArray, bias::Bool, dims::Integer...) bias ? fill!(similar(weights, dims...), 0) : false end -function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) +function _create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...) size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))")) bias end @@ -518,6 +533,18 @@ Normally, the throttled function will run as much as it can, without ever going more than once per `wait` duration; but if you'd like to disable the execution on the leading edge, pass `leading=false`. To enable execution on the trailing edge, pass `trailing=true`. + +# Examples +```jldoctest +julia> a = Flux.throttle(() -> println("Flux"), 2); + +julia> for i = 1:4 # a called in alternate iterations + a() + sleep(1) + end +Flux +Flux +``` """ function throttle(f, timeout; leading=true, trailing=false) cooldown = true