Skip to content

Commit

Permalink
Fix the implementation and docstring of tversky_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Jun 9, 2022
1 parent d6cc893 commit e06f5c5
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,31 +511,31 @@ end
Return the [Tversky loss](https://arxiv.org/abs/1706.05721).
Used with imbalanced data to give more weight to false negatives.
Larger β weigh recall more than precision (by placing more emphasis on false negatives)
Larger β weigh recall more than precision (by placing more emphasis on false negatives).
Calculated as:
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + β*(1 .- y) .* ŷ + (1 - β)*y .* (1 .- ŷ)) + 1)
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)
# Example
```jldoctest
julia> = [1, 0, 1, 1, 0];
julia> y = [0, 1, 0, 1, 1, 1];
julia> y = [1, 0, 0, 1, 0]; # one false negative data point
julia> ŷ_fp = [1, 1, 1, 1, 1, 1]; # 2 false positive -> 2 wrong predictions
julia> Flux.tversky_loss(ŷ, y)
0.18918918918918926
julia> ŷ_fnp = [1, 1, 0, 1, 1, 0]; # 1 false negative, 1 false positive -> 2 wrong predictions
julia> y = [1, 1, 1, 1, 0]; # No false negatives, but a false positive
julia> Flux.tversky_loss(ŷ_fnp, y)
0.19999999999999996
julia> Flux.tversky_loss(, y) # loss is smaller as more weight given to the false negatives
0.06976744186046513
julia> Flux.tversky_loss(ŷ_fp, y) # should be smaller than tversky_loss(ŷ_fnp, y), as FN is given more weight
0.1071428571428571
```
"""
function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
_check_sizes(ŷ, y)
#TODO add agg
num = sum(y .* ŷ) + 1
den = sum(y .*+ β * (1 .- y) .*+ (1 - β) * y .* (1 .- ŷ)) + 1
den = sum(y .*+ (1 - β) * (1 .- y) .*+ β * y .* (1 .- ŷ)) + 1
1 - num / den
end

Expand Down

0 comments on commit e06f5c5

Please sign in to comment.