Merge pull request #1998 from Saransh-cpp/docstrings-for-utils-and-losses

Miscellaneous docstring additions and fixes
ToucheSir authored Aug 17, 2022
2 parents f9b95c4 + f49ec34 commit d4f1d81
Showing 7 changed files with 194 additions and 58 deletions.
4 changes: 4 additions & 0 deletions docs/src/utilities.md
@@ -42,7 +42,11 @@ Flux.orthogonal
Flux.sparse_init
Flux.identity_init
Flux.ones32
Flux.zeros32
Flux.rand32
Flux.randn32
Flux.rng_from_array
Flux.default_rng_value
```

## Changing the type of model parameters
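Editorial note (not part of the commit): the entries added to the docs listing above are Flux's Float32 convenience initialisers and RNG helpers. A rough usage sketch, assuming the Flux 0.13-era API:

```julia
julia> using Flux

julia> Flux.ones32(2, 3)        # Float32 analogue of ones(2, 3); zeros32 is the same for zeros
2×3 Matrix{Float32}:
 1.0  1.0  1.0
 1.0  1.0  1.0

julia> eltype(Flux.rand32(4))   # rand32 / randn32 also produce Float32 arrays
Float32
```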
6 changes: 3 additions & 3 deletions src/layers/basic.jl
@@ -155,7 +155,7 @@ struct Dense{F, M<:AbstractMatrix, B}
bias::B
σ::F
function Dense(W::M, bias = true, σ::F = identity) where {M<:AbstractMatrix, F}
b = create_bias(W, bias, size(W,1))
b = _create_bias(W, bias, size(W,1))
new{F,M,typeof(b)}(W, b, σ)
end
end
@@ -228,7 +228,7 @@ struct Scale{F, A<:AbstractArray, B}
bias::B
σ::F
function Scale(scale::A, bias::B = true, σ::F = identity) where {A<:AbstractArray, B<:Union{Bool, AbstractArray}, F}
b = create_bias(scale, bias, size(scale)...)
b = _create_bias(scale, bias, size(scale)...)
new{F, A, typeof(b)}(scale, b, σ)
end
end
@@ -403,7 +403,7 @@ struct Bilinear{F,A,B}
σ::F
function Bilinear(W::A, bias = true, σ::F = identity) where {A<:AbstractArray, F}
ndims(A) == 3 || throw(ArgumentError("expected a 3-array of weights"))
b = create_bias(W, bias, size(W,1))
b = _create_bias(W, bias, size(W,1))
new{F,A,typeof(b)}(W, b, σ)
end
end
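The only change in src/layers/basic.jl is the rename `create_bias` → `_create_bias`, marking the helper as internal. For orientation, a hypothetical sketch of what such a helper does, inferred purely from its call sites above — the name and body below are assumptions, not the actual Flux implementation:

```julia
# Hypothetical sketch: build a bias term compatible with the given weights.
function _create_bias_sketch(weights::AbstractArray, bias, dims::Integer...)
  bias === true  && return fill!(similar(weights, dims...), 0)  # trainable zero bias, same eltype as weights
  bias === false && return false                                # no bias term at all
  size(bias) == dims ||
    throw(DimensionMismatch("bias must have size $dims, got $(size(bias))"))
  return bias                                                   # user-supplied bias array, used as-is
end
```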
6 changes: 3 additions & 3 deletions src/layers/conv.jl
@@ -156,7 +156,7 @@ function Conv(w::AbstractArray{T,N}, b = true, σ = identity;
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(Conv, pad, size(w)[1:N-2], dilation, stride)
bias = create_bias(w, b, size(w, N))
bias = _create_bias(w, b, size(w, N))
return Conv(σ, w, bias, stride, pad, dilation, groups)
end

@@ -293,7 +293,7 @@ function ConvTranspose(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(ConvTranspose, pad, size(w)[1:N-2], dilation, stride)
b = create_bias(w, bias, size(w, N-1) * groups)
b = _create_bias(w, bias, size(w, N-1) * groups)
return ConvTranspose(σ, w, b, stride, pad, dilation, groups)
end

@@ -441,7 +441,7 @@ function CrossCor(w::AbstractArray{T,N}, bias = true, σ = identity;
stride = expand(Val(N-2), stride)
dilation = expand(Val(N-2), dilation)
pad = calc_padding(CrossCor, pad, size(w)[1:N-2], dilation, stride)
b = create_bias(w, bias, size(w, N))
b = _create_bias(w, bias, size(w, N))
return CrossCor(σ, w, b, stride, pad, dilation)
end

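Editorial aside (not from the commit): only the helper rename touches this file, so the `bias` keyword of the convolutional layers keeps its documented behaviour, e.g.:

```julia
using Flux

c = Conv((3, 3), 1 => 4; bias = false)  # bias = false ⇒ no bias parameters are created
c.bias === false                        # the layer stores `false` instead of a vector

c2 = Conv((3, 3), 1 => 4)               # default bias = true ⇒ zero bias vector, one entry per output channel
size(c2.bias)                           # (4,)
```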
12 changes: 6 additions & 6 deletions src/layers/normalise.jl
@@ -51,7 +51,7 @@ end
ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)

"""
Dropout(p; dims=:, rng = rng_from_array())
Dropout(p; dims=:, rng = default_rng_value())
Dropout layer.
@@ -96,9 +96,9 @@ mutable struct Dropout{F,D,R<:AbstractRNG}
active::Union{Bool, Nothing}
rng::R
end
Dropout(p, dims, active) = Dropout(p, dims, active, rng_from_array())
Dropout(p, dims, active) = Dropout(p, dims, active, default_rng_value())

function Dropout(p; dims=:, rng = rng_from_array())
function Dropout(p; dims=:, rng = default_rng_value())
@assert 0 ≤ p ≤ 1
Dropout(p, dims, nothing, rng)
end
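Not part of the diff — a short sketch of using the `rng` keyword whose default changed here:

```julia
using Flux, Random

d = Dropout(0.4; rng = MersenneTwister(0))  # explicit RNG instead of the default_rng_value() default
Flux.trainmode!(d)                          # force the stochastic behaviour outside of a training loop
x = ones(Float32, 5, 2)
y = d(x)                                    # ≈ 40% of entries zeroed, the rest scaled by 1/(1 - 0.4)
```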
@@ -121,7 +121,7 @@ function Base.show(io::IO, d::Dropout)
end

"""
AlphaDropout(p; rng = rng_from_array())
AlphaDropout(p; rng = default_rng_value())
A dropout layer. Used in
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -155,8 +155,8 @@ mutable struct AlphaDropout{F,R<:AbstractRNG}
new{typeof(p), typeof(rng)}(p, active, rng)
end
end
AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)
AlphaDropout(p, active) = AlphaDropout(p, active, default_rng_value())
AlphaDropout(p; rng = default_rng_value()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout
trainable(a::AlphaDropout) = (;)
19 changes: 19 additions & 0 deletions src/layers/stateless.jl
@@ -32,6 +32,25 @@ end
Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`.
By default, `dims` is the last dimension.
`ϵ` is a small additive factor added to the denominator for numerical stability.
# Examples
```jldoctest
julia> using Statistics

julia> x = [9, 10, 20, 60];

julia> y = Flux.normalise(x);

julia> isapprox(std(y), 1, atol=0.2) && std(y) != std(x)
true

julia> x = rand(1:100, 10, 2);

julia> y = Flux.normalise(x, dims=1);

julia> isapprox(std(y, dims=1), ones(1, 2), atol=0.2) && std(y, dims=1) != std(x, dims=1)
true
```
"""
@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
μ = mean(x, dims=dims)
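The hunk above is cut off after the first line of the body. For context, the remainder of `normalise` is essentially standardisation; the sketch below is a paraphrase under that assumption, not the exact source:

```julia
using Statistics

# Paraphrased sketch of normalise: subtract the mean and divide by the (uncorrected) std along dims.
function normalise_sketch(x::AbstractArray; dims = ndims(x), ϵ = 1f-5)
  μ = mean(x, dims = dims)
  σ = std(x, dims = dims, mean = μ, corrected = false)
  return @. (x - μ) / (σ + ϵ)
end
```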
112 changes: 99 additions & 13 deletions src/losses/functions.jl
@@ -80,6 +80,17 @@ given the prediction `ŷ` and true values `y`.
             | 0.5 * |ŷ - y|^2,            for |ŷ - y| <= δ
Huber loss = |
             | δ * (|ŷ - y| - 0.5 * δ),    otherwise

# Example
```jldoctest
julia> ŷ = [1.1, 2.1, 3.1];

julia> Flux.huber_loss(ŷ, 1:3) # default δ = 1 > |ŷ - y|
0.005000000000000009

julia> Flux.huber_loss(ŷ, 1:3, δ=0.05) # changes behaviour as |ŷ - y| > δ
0.003750000000000005
```
"""
function huber_loss(ŷ, y; agg = mean, δ = ofeltype(ŷ, 1))
_check_sizes(ŷ, y)
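A quick by-hand check of the two doctest values above (editorial, not in the commit):

```julia
using Statistics

ŷ, y = [1.1, 2.1, 3.1], [1, 2, 3]
abs.(ŷ .- y)                           # ≈ [0.1, 0.1, 0.1], all below the default δ = 1
mean(0.5 .* abs2.(ŷ .- y))             # quadratic branch ⇒ ≈ 0.005

δ = 0.05                               # now |ŷ - y| > δ, so the linear branch applies
mean(δ .* (abs.(ŷ .- y) .- 0.5 * δ))   # ⇒ ≈ 0.00375
```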
@@ -377,12 +388,22 @@ function kldivergence(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
end

"""
poisson_loss(ŷ, y)
poisson_loss(ŷ, y; agg = mean)
# Return how much the predicted distribution `ŷ` diverges from the expected Poisson
# distribution `y`; calculated as `sum(ŷ .- y .* log.(ŷ)) / size(y, 2)`.
Return how much the predicted distribution `ŷ` diverges from the expected Poisson
distribution `y`; calculated as -
sum(ŷ .- y .* log.(ŷ)) / size(y, 2)
[More information.](https://peltarion.com/knowledge-center/documentation/modeling-view/build-an-ai-model/loss-functions/poisson).
# Example
```jldoctest
julia> y_model = [1, 3, 3]; # data should only take integral values

julia> Flux.poisson_loss(y_model, 1:3)
0.5023128522198171
```
"""
function poisson_loss(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
@@ -392,11 +413,32 @@ end
"""
hinge_loss(ŷ, y; agg = mean)
Return the [hinge_loss loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
Return the [hinge_loss](https://en.wikipedia.org/wiki/Hinge_loss) given the
prediction `ŷ` and true labels `y` (containing 1 or -1); calculated as
`sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)`.
sum(max.(0, 1 .- ŷ .* y)) / size(y, 2)
Usually used with classifiers like Support Vector Machines.
See also: [`squared_hinge_loss`](@ref)
# Example
```jldoctest
julia> y_true = [1, -1, 1, 1];

julia> y_pred = [0.1, 0.3, 1, 1.5];

julia> Flux.hinge_loss(y_pred, y_true)
0.55

julia> Flux.hinge_loss(y_pred[1], y_true[1]) != 0 # same sign but |ŷ| < 1
true

julia> Flux.hinge_loss(y_pred[end], y_true[end]) == 0 # same sign and |ŷ| >= 1
true

julia> Flux.hinge_loss(y_pred[2], y_true[2]) != 0 # opposite signs
true
```
"""
function hinge_loss(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
@@ -407,9 +449,31 @@ end
squared_hinge_loss(ŷ, y)
Return the squared hinge_loss loss given the prediction `ŷ` and true labels `y`
(containing 1 or -1); calculated as `sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)`.
(containing 1 or -1); calculated as
sum((max.(0, 1 .- ŷ .* y)).^2) / size(y, 2)
Usually used with classifiers like Support Vector Machines.
See also: [`hinge_loss`](@ref)
# Example
```jldoctest
julia> y_true = [1, -1, 1, 1];

julia> y_pred = [0.1, 0.3, 1, 1.5];

julia> Flux.squared_hinge_loss(y_pred, y_true)
0.625

julia> Flux.squared_hinge_loss(y_pred[1], y_true[1]) != 0
true

julia> Flux.squared_hinge_loss(y_pred[end], y_true[end]) == 0
true

julia> Flux.squared_hinge_loss(y_pred[2], y_true[2]) != 0
true
```
"""
function squared_hinge_loss(ŷ, y; agg = mean)
_check_sizes(ŷ, y)
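The element-wise arithmetic behind the `0.55` and `0.625` in the two hinge examples above, as an editorial check:

```julia
using Statistics

y_true, y_pred = [1, -1, 1, 1], [0.1, 0.3, 1, 1.5]
margins = max.(0, 1 .- y_pred .* y_true)  # [0.9, 1.3, 0.0, 0.0]
mean(margins)                             # hinge_loss         ⇒ 0.55
mean(margins .^ 2)                        # squared_hinge_loss ⇒ 0.625
```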
@@ -422,9 +486,20 @@ end
Return a loss based on the dice coefficient.
Used in the [V-Net](https://arxiv.org/abs/1606.04797) image segmentation
architecture.
Similar to the F1_score. Calculated as:
The dice coefficient is similar to the F1_score. Loss calculated as:
1 - 2*sum(|ŷ .* y| + smooth) / (sum(ŷ.^2) + sum(y.^2) + smooth)
# Example
```jldoctest
julia> y_pred = [1.1, 2.1, 3.1];

julia> Flux.dice_coeff_loss(y_pred, 1:3)
0.000992391663909964

julia> 1 - Flux.dice_coeff_loss(y_pred, 1:3) # ~ F1 score for image segmentation
0.99900760833609
```
"""
function dice_coeff_loss(ŷ, y; smooth = ofeltype(ŷ, 1.0))
_check_sizes(ŷ, y)
@@ -436,9 +511,11 @@ end
Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
Used with imbalanced data to give more weight to false negatives.
Larger β weigh recall more than precision (by placing more emphasis on false negatives)
Larger β weigh recall more than precision (by placing more emphasis on false negatives).
Calculated as:
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)
"""
function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
_check_sizes(ŷ, y)
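`tversky_loss` is the one loss here that still has no doctest after this commit. A hedged usage sketch with made-up binary masks (values invented for illustration):

```julia
using Flux

y = [1, 0, 1, 1, 0]               # hypothetical ground-truth mask
ŷ = [0.9, 0.1, 0.8, 0.3, 0.2]     # hypothetical predicted probabilities
Flux.tversky_loss(ŷ, y)           # default β = 0.7 puts more weight on false negatives
Flux.tversky_loss(ŷ, y; β = 0.3)  # smaller β shifts the weight toward false positives
```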
@@ -456,6 +533,8 @@ The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) out
For `γ == 0`, the loss is mathematically equivalent to [`Losses.binarycrossentropy`](@ref).
See also: [`Losses.focal_loss`](@ref) for multi-class setting
# Example
```jldoctest
julia> y = [0 1 0
@@ -473,9 +552,6 @@ julia> ŷ = [0.268941 0.5 0.268941
julia> Flux.binary_focal_loss(ŷ, y) ≈ 0.0728675615927385
true
```
See also: [`Losses.focal_loss`](@ref) for multi-class setting
"""
function binary_focal_loss(ŷ, y; agg=mean, γ=2, ϵ=epseltype(ŷ))
_check_sizes(ŷ, y)
@@ -536,7 +612,17 @@ which can be useful for training Siamese Networks. It is given by
agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
Specify `margin` to set the baseline for distance at which pairs are dissimilar.
# Example
```jldoctest
julia> ŷ = [0.5, 1.5, 2.5];

julia> Flux.siamese_contrastive_loss(ŷ, 1:3)
-4.833333333333333

julia> Flux.siamese_contrastive_loss(ŷ, 1:3, margin = 2)
-4.0
```
"""
function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
_check_sizes(ŷ, y)
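Editorial footnote: the doctest above passes `1:3` as `y`; with the usual 0/1 labels the loss is non-negative, e.g.:

```julia
using Flux

ŷ = [0.3, 1.2]  # pairwise distances from a Siamese network (invented values)
y = [0, 1]      # 0 ⇒ the formula pulls ŷ toward 0; 1 ⇒ pushes ŷ up toward the margin
Flux.siamese_contrastive_loss(ŷ, y)  # mean of [0.3^2, max(0, 1 - 1.2)^2] = 0.045
```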