diff --git a/src/lib/statsfuns.jl b/src/lib/statsfuns.jl
index 40c7fddf4..93652a251 100644
--- a/src/lib/statsfuns.jl
+++ b/src/lib/statsfuns.jl
@@ -25,10 +25,7 @@ end
   return log1pexp(x), Δ->(Δ * (x < 9f0 ? logistic(x) : x < 16f0 ? 1 - exp(-x) : 1),)
 end
 
-@adjoint function logsumexp(X::AbstractArray{<:Real})
-  return logsumexp(X), function(Δ)
-    y = StatsFuns.softmax(X)
-    y .*= Δ
-    return (y,)
-  end
+@adjoint function logsumexp(X::AbstractArray{<:Real}; dims=:)
+  lse = logsumexp(X; dims=dims)
+  return lse, Δ -> (Δ .* exp.(X .- lse),)
 end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index e97de02f2..e408d1f12 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1039,6 +1039,7 @@ end
     @test gradtest(StatsFuns.logsumexp, randn(rng, 1, 1))
     @test gradtest(StatsFuns.logsumexp, randn(rng, 3))
     @test gradtest(StatsFuns.logsumexp, randn(rng, 3, 4, 5))
+    @test gradtest(x -> sum(StatsFuns.logsumexp(x; dims=1)), randn(rng, 4, 4))
   end
 end
 