Simplify softmax, test second derivatives (#393)
* simplify softmax, test second derivatives

* add a note about Flux to docstring

* add Tracker to downstream tests

* missing semicolon

* remove x arguments, rename

* move exports

* change the notation

* tidy, add x::AbstractArray

* add fastmath

* also logsumexp

* version, and trigger CI

* upgrade ci
mcabbott authored Mar 5, 2022
1 parent 172549c commit 0c8396e
Showing 7 changed files with 163 additions and 75 deletions.
27 changes: 13 additions & 14 deletions .buildkite/pipeline.yml
@@ -24,20 +24,19 @@ steps:
NNLIB_TEST_CUDA: true
timeout_in_minutes: 60

## Add these when julia 1.7 is out
# - label: "GPU julia v1"
# plugins:
# - JuliaCI/julia#v1:
# version: "1"
# - JuliaCI/julia-test#v1: ~
# - JuliaCI/julia-coverage#v1:
# codecov: true
# dirs:
# - src
# agents:
# queue: "juliagpu"
# cuda: "*"
# timeout_in_minutes: 60
- label: "GPU julia v1"
plugins:
- JuliaCI/julia#v1:
version: "1"
- JuliaCI/julia-test#v1: ~
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
agents:
queue: "juliagpu"
cuda: "*"
timeout_in_minutes: 60

# - label: "GPU julia nightly"
# plugins:
1 change: 1 addition & 0 deletions .github/workflows/Downstream.yml
@@ -18,6 +18,7 @@ jobs:
os: [ubuntu-latest]
package:
- {user: FluxML, repo: Flux.jl, group: All}
- {user: FluxML, repo: Tracker.jl, group: All}
- {user: denizyuret, repo: Knet.jl, group: All}
- {user: dfdx, repo: Avalon.jl, group: All}
- {user: JuliaOptimalTransport, repo: OptimalTransport.jl, group: All}
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -23,6 +23,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.8.2"
version = "0.8.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Adapt = "2, 3.2"
ChainRulesCore = "0.9.45, 0.10, 1"
ChainRulesCore = "1.13"
Compat = "3.14"
Requires = "0.5, 1.0"
julia = "1.6"
57 changes: 56 additions & 1 deletion src/deprecations.jl
@@ -1 +1,56 @@
### v0.8 Deprecations

### Deprecated while v0.7 was latest

function ∇softmax(Δ, x; dims = 1)
# This 2-arg version recomputes the forward pass, which is slow.
# Removed from use in 0.7, but only prints a warning during 0.8:
Base.depwarn("`∇softmax(Δ, x)` without `y = softmax(x)` argument is deprecated, as this is inefficient, please use `∇softmax_data(dy, y)`", :∇softmax)
∇softmax(Δ, x, softmax(x; dims); dims)
end
∇softmax!(Δ, x; dims = 1) = Δ .= ∇softmax(Δ, x; dims)
∇softmax!(out, Δ, x; dims = 1) = out .= ∇softmax(Δ, x; dims)

function ∇logsoftmax(Δ, x; dims = 1)
Base.depwarn("`∇logsoftmax(Δ, x)` without `y = logsoftmax(x)` argument is deprecated, please use `∇logsoftmax_data(dy, y)`", :∇logsoftmax)
∇logsoftmax(Δ, x, logsoftmax(x; dims); dims)
end
∇logsoftmax!(Δ, x; dims = 1) = Δ .= ∇logsoftmax(Δ, x; dims)
∇logsoftmax!(out, Δ, x; dims = 1) = out .= ∇logsoftmax(Δ, x; dims)


### Deprecated while v0.8 was latest

export ∇softmax,
∇softmax!,
logsoftmax,
logsoftmax!,
∇logsoftmax,
∇logsoftmax!

function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
x::AbstractArray, y::AbstractArray; dims = 1)
Base.depwarn("`∇softmax!(dx, dy, x, y)` is deprecated, just use `∇softmax_data(dy, y)`", :∇softmax!)
# Removed because using a mutating function blocks 2nd derivatives, and
# the CUDA overload was slow anyway, https://github.com/FluxML/NNlibCUDA.jl/issues/30
out .= Δ .* y
out .= out .- y .* sum(out; dims)
end

function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
x::AbstractArray, y::AbstractArray; dims = 1)
Base.depwarn("`∇logsoftmax!(dx, dy, x, y)` is deprecated, just use `∇logsoftmax_data(dy, y)`", :∇softmax!)
out .= Δ .- sum(Δ; dims) .* exp.(y)
end

function ∇softmax(dy::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S}
# Removed because there's no need to close over `x` here, that was done only to distinguish
# this from `∇softmax(Δ, x; dims = 1)` which re-computed `y = softmax(x)`, which is slow.
Base.depwarn("`∇softmax(dy, x, y)` should be replaced with `∇softmax_data(dy, y)`", :∇softmax)
∇softmax_data(dy, y)
end

function ∇logsoftmax(dy::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)
Base.depwarn("`∇logsoftmax(dy, x, y)` should be replaced with `∇logsoftmax_data(dy, y)`", :∇softmax)
∇logsoftmax_data(dy, y)
end
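In short, the deprecations above all funnel into the same migration: the old gradient helpers took `x` (and, in the 2-arg forms, recomputed the forward pass), while the new `∇softmax_data(dy, y)` and `∇logsoftmax_data(dy, y)` only need the saved output `y`. A minimal sketch of the before/after call patterns, with made-up arrays `x` and `dy`; the explicit `using NNlib:` line mirrors the test file, since these helpers are not necessarily exported:

```julia
using NNlib
using NNlib: ∇softmax_data, ∇logsoftmax_data

x  = randn(Float32, 5, 3)   # hypothetical input
dy = randn(Float32, 5, 3)   # hypothetical cotangent, same shape as softmax(x)

# Deprecated (prints a depwarn; the 2-arg form also recomputes softmax(x)):
#   dx = ∇softmax(dy, x; dims = 1)
#   dx = ∇softmax(dy, x, softmax(x); dims = 1)

# Current API: keep the forward result and pass it instead of x.
y  = softmax(x; dims = 1)
dx = ∇softmax_data(dy, y; dims = 1)

ylog  = logsoftmax(x; dims = 1)
dxlog = ∇logsoftmax_data(dy, ylog; dims = 1)
```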

117 changes: 68 additions & 49 deletions src/softmax.jl
@@ -1,3 +1,4 @@

"""
softmax(x; dims = 1)
@@ -33,45 +34,63 @@ julia> softmax([1 2 3; 2 2 2]; dims=2)
0.0900306 0.244728 0.665241
0.333333 0.333333 0.333333
```
Note that, when used with Flux.jl, `softmax` must not be passed to layers like `Dense`
which accept an activation function. The activation is broadcast over the layer's output,
so it applies to individual numbers, whereas `softmax` always needs to see the whole column.
```julia
julia> using Flux
julia> x = randn(Float32, 4, 4, 3, 13);
julia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);
julia> model(x) |> size
(7, 13)
julia> Dense(4 => 7, softmax)(x)
ERROR: `softmax(x)` called with a number, but it expects an array.
```
"""
softmax(x; dims = 1) = softmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
softmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T)), x; dims)

softmax!(x; dims = 1) = softmax!(x, x; dims = dims)
softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)

function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
max_ = maximum(x; dims = dims)
max_ = maximum(x; dims)
if all(isfinite, max_)
out .= exp.(x .- max_)
@fastmath out .= exp.(x .- max_)
else
@. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
@fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
end
out ./= sum(out; dims = dims) # could re-use max_ when dims != (:) and eltype(x) == T.
out ./= sum(out; dims)
end

∇softmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
∇softmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
∇softmax(Δ, x, y; dims = 1) = ∇softmax(unthunk(Δ), x, y, dims = dims)

# Can introduce at the end of deprecation cycle of ∇softmax!(out, Δ, x; dims = 1)
# ∇softmax!(Δ, x, y; dims = 1) = ∇softmax!(Δ, Δ, x, y; dims = dims)

function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
x::AbstractArray, y::AbstractArray; dims = 1)
out .= Δ .* y
out .= out .- y .* sum(out; dims = dims)
function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
dx = if within_grad()
tmp = dy .* y
tmp .- y .* sum(tmp; dims)
else
# This path is faster, only safe for 1st derivatives though.
# Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
# but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
out = similar(y, promote_type(T,S))
out .= dy .* y
out .= out .- y .* sum(out; dims)
end
end

# Old 2-arg version recomputing forward
∇softmax(Δ, x; dims = 1) = ∇softmax(Δ, x, softmax(x, dims = dims); dims = dims)
∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x, softmax(x, dims = dims); dims = dims)
∇softmax!(out, Δ, x; dims = 1) = ∇softmax!(out, Δ, x, softmax(x, dims = dims); dims = dims)

function rrule(::typeof(softmax), xs; dims=1)
y = softmax(xs; dims=dims)
softmax_pullback(Δ) = (NoTangent(), ∇softmax(unthunk(Δ), xs, y, dims = dims))
function rrule(::typeof(softmax), x; dims = 1)
y = softmax(x; dims)
softmax_pullback(dy) = (NoTangent(), ∇softmax_data(unthunk(dy), y; dims))
return y, softmax_pullback
end

within_grad() = false
rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)


"""
logsoftmax(x; dims = 1)
@@ -85,52 +104,52 @@ It is semantically equivalent to the following:
See also [`softmax`](@ref).
"""
logsoftmax(x; dims = 1) = logsoftmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
logsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, float(T)), x; dims)

logsoftmax!(x; dims = 1) = logsoftmax!(x, x; dims = dims)
logsoftmax!(x::AbstractArray; dims = 1) = logsoftmax!(x, x; dims)

function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
max_ = maximum(x; dims = dims)
max_ = maximum(x; dims)
if all(isfinite, max_)
out .= x .- max_
else
@. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 0, -Inf), x - max_)
end
log_ = log.(sum(exp, out; dims = dims))
@fastmath log_ = log.(sum(exp, out; dims))
out .-= log_
end

∇logsoftmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
∇logsoftmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
∇logsoftmax(Δ, x, y; dims = 1) = ∇logsoftmax(unthunk(Δ), x, y, dims = dims)

# Old 2-arg version recomputing forward
∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax(Δ, x, logsoftmax(x, dims = dims); dims = dims)
∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x, logsoftmax(x, dims = dims); dims = dims)
∇logsoftmax!(out, Δ, x; dims = 1) = ∇logsoftmax!(out, Δ, x, logsoftmax(x, dims = dims); dims = dims)

function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
x::AbstractArray, y::AbstractArray; dims = 1)
out .= Δ .- sum(Δ, dims = dims) .* exp.(y)
function ∇logsoftmax_data(dy::AbstractArray, y::AbstractArray; dims = 1)
# This was previously `∇logsoftmax!(dx, dy, x, y; dims)` to allow CUDA overloads, but that was slow.
dx = dy .- sum(dy; dims) .* exp.(y)
end

function rrule(::typeof(logsoftmax), xs; dims=1)
y = logsoftmax(xs; dims=dims)
logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(unthunk(Δ), xs, y, dims = dims))
function rrule(::typeof(logsoftmax), x; dims = 1)
y = logsoftmax(x; dims)
logsoftmax_pullback(dy) = (NoTangent(), ∇logsoftmax_data(unthunk(dy), y; dims))
return y, logsoftmax_pullback
end

"""
logsumexp(x; dims = :)
Computes `log.(sum(exp.(x); dims = dims))` in a numerically stable
way.
Computes `log.(sum(exp.(x); dims))` in a numerically stable way.
Without `dims` keyword this returns a scalar.
See also [`logsoftmax`](@ref).
"""
function logsumexp(x::AbstractArray; dims = :)
max_ = maximum(x; dims = dims)
max_ .+ log.(sum(exp.(x .- max_); dims = dims))
max_ = maximum(x; dims)
@fastmath max_ .+ log.(sum(exp.(x .- max_); dims))
end

function rrule(::typeof(logsumexp), x; dims = :)
# The gradient is `softmax`, but both compute `tmp` so it's worth saving.
max_ = maximum(x; dims)
@fastmath tmp = exp.(x .- max_)
@fastmath y = max_ .+ log.(sum(tmp; dims))
logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))
return y, logsumexp_pullback
end

# Informative error message if any of the softmax variants is called with a number
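The `within_grad` pair defined above is what makes the new second-derivative support work: the plain function returns `false`, so `∇softmax_data` takes its faster buffer-reusing branch, while the `rrule` reports `true` to any ChainRules-aware AD, which therefore sees only non-mutating broadcasts that can be differentiated again. A rough sketch of the effect, assuming Zygote is loaded (`within_grad` is an internal, unexported helper; `hessian_dual`/`hessian_reverse` are the same Zygote helpers the new tests use):

```julia
using NNlib, Zygote

NNlib.within_grad()   # false: ordinary calls take the in-place, first-derivative-only path

# Under AD the rrule replaces that value with true, so the softmax pullback is
# built from plain broadcasts and can itself be differentiated.  Forward-over-reverse
# and reverse-over-reverse Hessians of a softmax-based loss then agree:
x  = [1.0 2.0 3.0; 6.0 5.0 4.0]
H1 = Zygote.hessian_dual(x -> sum(sin, softmax(x)), x)
H2 = Zygote.hessian_reverse(x -> sum(sin, softmax(x)), x)
H1 ≈ H2   # true
```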
31 changes: 22 additions & 9 deletions test/softmax.jl
@@ -1,4 +1,5 @@
using Statistics: mean
using NNlib: ∇softmax_data, ∇logsoftmax_data

@testset "softmax integer input" begin
@test softmax(Int[0, 0]) == [0.5, 0.5]
@@ -34,10 +35,10 @@ end
@test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0]

y = logsoftmax(xs)
@test ∇logsoftmax(ones(Float32, size(xs)), xs, y) ≈ Float32[1 1 1; -1 -1 -1]
@test ∇logsoftmax_data(ones(Float32, size(xs)), y) ≈ Float32[1 1 1; -1 -1 -1]

y = softmax(xs)
@test ∇softmax(ones(Float32, size(xs)), xs, y) ≈ zeros(Float32, size(xs))
@test ∇softmax_data(ones(Float32, size(xs)), y) ≈ zeros(Float32, size(xs))

# These values precalculated using PyTorch's nn.LogSoftmax
xs = [
@@ -52,10 +53,10 @@
]

y = logsoftmax(xs)
@test ∇logsoftmax(ones(size(xs)), xs, y) ≈ ys rtol = 1e-6
@test ∇logsoftmax_data(ones(size(xs)), y) ≈ ys rtol = 1e-6

y = softmax(xs)
@test ∇softmax(ones(size(xs)), xs, y) ≈ zeros(size(xs)) atol = 1e-6
@test ∇softmax_data(ones(size(xs)), y) ≈ zeros(size(xs)) atol = 1e-6
end

@testset "softmax with Inf, NaN" begin
@@ -91,12 +92,12 @@ end
@testset "$fn(Float64, $(size(xs)))" for fn in [zeros, ones, rand]
Δ = fn(Float64, size(xs))
y = softmax(xs)
∇softmax!(out, Δ, xs, y)
@test out ≈ ∇softmax(Δ, xs, y) rtol = 1e-6
∇softmax!(out, Δ, xs, y) # deprecated
@test out ≈ ∇softmax_data(Δ, y) rtol = 1e-6

y = logsoftmax(xs)
∇logsoftmax!(out, Δ, xs, y)
@test out ≈ ∇logsoftmax(Δ, xs, y) rtol = 1e-6
∇logsoftmax!(out, Δ, xs, y) # deprecated
@test out ≈ ∇logsoftmax_data(Δ, y) rtol = 1e-6
end
end
end
@@ -109,14 +110,14 @@ end
@test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1)
end


@testset "AutoDiff" begin
for f in (softmax, logsoftmax), d in (:, 1, 2)
gradtest(f, (3,4); fkwargs = (dims = d,), check_rrule = true)
end
gradtest(x -> softmax(x) .* (1:3), 3)
gradtest(x -> softmax(x) .* (1:3), (3,5), atol = 1e-4)
gradtest(x -> softmax(x, dims = 2) .* (1:3), (3,5), atol = 1e-4)

gradtest(x -> logsoftmax(x) .* (1:3), 3)
gradtest(x -> logsoftmax(x) .* (1:3), (3,5))
gradtest(x -> logsoftmax(x, dims = 2) .* (1:3), (3,5))
@@ -125,3 +126,15 @@ end
gradtest(logsumexp, (3,4), fkwargs = (dims = d,))
end
end

@testset "Second derivatives" begin
x = [1 2 3; 6 5 4]
H = Zygote.hessian_dual(x -> sum(sin, softmax(x)), x)
@test H ≈ Zygote.hessian_reverse(x -> sum(sin, softmax(x)), x)

H2 = Zygote.hessian_dual(x -> sum(sin, logsoftmax(x)), x)
@test H2 ≈ Zygote.hessian_reverse(x -> sum(sin, logsoftmax(x)), x)

H3 = Zygote.hessian_dual(x -> sum(sin, logsumexp(x)), x)
@test H3 ≈ Zygote.hessian_reverse(x -> sum(sin, logsumexp(x)), x)
end
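The `logsumexp` case in this testset works because its gradient is exactly `softmax`, which is also why the new `rrule` in src/softmax.jl caches `tmp = exp.(x .- max_)` and reuses it in the pullback rather than recomputing it. A small check of that identity, assuming Zygote is loaded (the array `x` here is made up):

```julia
using NNlib, Zygote

x = randn(3, 4)
# Summing the per-column logsumexp values, the gradient with respect to x is the
# column-wise softmax: d/dx[i,j] logsumexp(x[:, j]) = softmax(x[:, j])[i].
g = only(Zygote.gradient(x -> sum(logsumexp(x; dims = 1)), x))
g ≈ softmax(x; dims = 1)   # true
```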

2 comments on commit 0c8396e

@mcabbott (Member, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/56021

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.8.3 -m "<description of version>" 0c8396e2f2707d4c223fb45348897eae28b62e2e
git push origin v0.8.3
