Skip to content

Commit

Permalink
Revert tversky_loss changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Aug 16, 2022
1 parent a67400f commit f49ec34
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 17 deletions.
16 changes: 1 addition & 15 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,26 +516,12 @@ Calculated as:
1 - sum(|y .* ŷ| + 1) / (sum(y .* ŷ + (1 - β)*(1 .- y) .* ŷ + β*y .* (1 .- ŷ)) + 1)
# Example
```jldoctest
julia> y = [0, 1, 0, 1, 1, 1];
julia> ŷ_fp = [1, 1, 1, 1, 1, 1]; # 2 false positive -> 2 wrong predictions
julia> ŷ_fnp = [1, 1, 0, 1, 1, 0]; # 1 false negative, 1 false positive -> 2 wrong predictions
julia> Flux.tversky_loss(ŷ_fnp, y)
0.19999999999999996
julia> Flux.tversky_loss(ŷ_fp, y) < Flux.tversky_loss(ŷ_fnp, y) # FN is given more weight
true
```
"""
function tversky_loss(ŷ, y; β = ofeltype(ŷ, 0.7))
_check_sizes(ŷ, y)
#TODO add agg
num = sum(y .* ŷ) + 1
den = sum(y .*+ (1 - β) * (1 .- y) .*+ β * y .* (1 .- ŷ)) + 1
den = sum(y .*+ β * (1 .- y) .*+ (1 - β) * y .* (1 .- ŷ)) + 1
1 - num / den
end

Expand Down
4 changes: 2 additions & 2 deletions test/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ y = [1.0 0.5 0.3 2.4]
end

@testset "tversky_loss" begin
@test Flux.tversky_loss(ŷ, y) 0.028747433264887046
@test Flux.tversky_loss(ŷ, y, β=0.8) 0.050200803212851364
@test Flux.tversky_loss(ŷ, y) -0.06772009029345383
@test Flux.tversky_loss(ŷ, y, β=0.8) -0.09490740740740744
@test Flux.tversky_loss(y, y) -0.5576923076923075
end

Expand Down

0 comments on commit f49ec34

Please sign in to comment.