diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index 7dd029e75..f805f8c43 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -24,20 +24,19 @@ steps:
       NNLIB_TEST_CUDA: true
     timeout_in_minutes: 60

-  ## Add these when julia 1.7 is out
-  # - label: "GPU julia v1"
-  #   plugins:
-  #     - JuliaCI/julia#v1:
-  #         version: "1"
-  #     - JuliaCI/julia-test#v1: ~
-  #     - JuliaCI/julia-coverage#v1:
-  #         codecov: true
-  #         dirs:
-  #           - src
-  #   agents:
-  #     queue: "juliagpu"
-  #     cuda: "*"
-  #   timeout_in_minutes: 60
+  - label: "GPU julia v1"
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "1"
+      - JuliaCI/julia-test#v1: ~
+      - JuliaCI/julia-coverage#v1:
+          codecov: true
+          dirs:
+            - src
+    agents:
+      queue: "juliagpu"
+      cuda: "*"
+    timeout_in_minutes: 60

   # - label: "GPU julia nightly"
   #   plugins:
diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml
index 79d0c51b6..1bd1c511d 100644
--- a/.github/workflows/Downstream.yml
+++ b/.github/workflows/Downstream.yml
@@ -18,6 +18,7 @@ jobs:
         os: [ubuntu-latest]
         package:
           - {user: FluxML, repo: Flux.jl, group: All}
+          - {user: FluxML, repo: Tracker.jl, group: All}
           - {user: denizyuret, repo: Knet.jl, group: All}
           - {user: dfdx, repo: Avalon.jl, group: All}
           - {user: JuliaOptimalTransport, repo: OptimalTransport.jl, group: All}
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b3a5e3506..4ae4e4bbe 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -23,6 +23,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
+          - '1.6'
           - '1' # automatically expands to the latest stable 1.x release of Julia
           - 'nightly'
         os:
diff --git a/Project.toml b/Project.toml
index 496560f35..339cd1098 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "NNlib"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.8.2"
+version = "0.8.3"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,7 +13,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [compat]
 Adapt = "2, 3.2"
-ChainRulesCore = "0.9.45, 0.10, 1"
+ChainRulesCore = "1.13"
 Compat = "3.14"
 Requires = "0.5, 1.0"
 julia = "1.6"
diff --git a/src/deprecations.jl b/src/deprecations.jl
index 6791fbae3..e76cbc796 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -1 +1,56 @@
-### v0.8 Deprecations
+
+### Deprecated while v0.7 was latest
+
+function ∇softmax(Δ, x; dims = 1)
+    # This 2-arg version recomputes the forward pass, which is slow.
+    # Removed from use in 0.7, but only prints a warning during 0.8:
+    Base.depwarn("`∇softmax(Δ, x)` without `y = softmax(x)` argument is deprecated, as this is inefficient, please use `∇softmax_data(dy, y)`", :∇softmax)
+    ∇softmax(Δ, x, softmax(x; dims); dims)
+end
+∇softmax!(Δ, x; dims = 1) = Δ .= ∇softmax(Δ, x; dims)
+∇softmax!(out, Δ, x; dims = 1) = out .= ∇softmax(Δ, x; dims)
+
+function ∇logsoftmax(Δ, x; dims = 1)
+    Base.depwarn("`∇logsoftmax(Δ, x)` without `y = logsoftmax(x)` argument is deprecated, please use `∇logsoftmax_data(dy, y)`", :∇logsoftmax)
+    ∇logsoftmax(Δ, x, logsoftmax(x; dims); dims)
+end
+∇logsoftmax!(Δ, x; dims = 1) = Δ .= ∇logsoftmax(Δ, x; dims)
+∇logsoftmax!(out, Δ, x; dims = 1) = out .= ∇logsoftmax(Δ, x; dims)
+
+
+### Deprecated while v0.8 was latest
+
+export ∇softmax,
+    ∇softmax!,
+    logsoftmax,
+    logsoftmax!,
+    ∇logsoftmax,
+    ∇logsoftmax!
+
+function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
+                   x::AbstractArray, y::AbstractArray; dims = 1)
+    Base.depwarn("`∇softmax!(dx, dy, x, y)` is deprecated, just use `∇softmax_data(dy, y)`", :∇softmax!)
+    # Removed because using a mutating function blocks 2nd derivatives, and
+    # the CUDA overload was slow anyway, https://github.com/FluxML/NNlibCUDA.jl/issues/30
+    out .= Δ .* y
+    out .= out .- y .* sum(out; dims)
+end
+
+function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
+                      x::AbstractArray, y::AbstractArray; dims = 1)
+    Base.depwarn("`∇logsoftmax!(dx, dy, x, y)` is deprecated, just use `∇logsoftmax_data(dy, y)`", :∇logsoftmax!)
+    out .= Δ .- sum(Δ; dims) .* exp.(y)
+end
+
+function ∇softmax(dy::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S}
+    # Removed because there's no need to close over `x` here, that was done only to distinguish
+    # this from `∇softmax(Δ, x; dims = 1)` which re-computed `y = softmax(x)`, which is slow.
+    Base.depwarn("`∇softmax(dy, x, y)` should be replaced with `∇softmax_data(dy, y)`", :∇softmax)
+    ∇softmax_data(dy, y)
+end
+
+function ∇logsoftmax(dy::AbstractArray, x::AbstractArray, y::AbstractArray; dims = 1)
+    Base.depwarn("`∇logsoftmax(dy, x, y)` should be replaced with `∇logsoftmax_data(dy, y)`", :∇logsoftmax)
+    ∇logsoftmax_data(dy, y)
+end
+
diff --git a/src/softmax.jl b/src/softmax.jl
index 89807c4c6..1596d9b46 100644
--- a/src/softmax.jl
+++ b/src/softmax.jl
@@ -1,3 +1,4 @@
+
 """
     softmax(x; dims = 1)

@@ -33,45 +34,63 @@ julia> softmax([1 2 3; 2 2 2]; dims=2)
  0.0900306  0.244728  0.665241
  0.333333   0.333333  0.333333
 ```
+
+Note that, when used with Flux.jl, `softmax` must not be passed to layers like `Dense`
+which accept an activation function. The activation is broadcasted over the result,
+thus applies to individual numbers. But `softmax` always needs to see the whole column.
+
+```julia
+julia> using Flux
+
+julia> x = randn(Float32, 4, 4, 3, 13);
+
+julia> model = Chain(Conv((4, 4), 3 => 8, tanh), Flux.flatten, Dense(8 => 7), softmax);
+
+julia> model(x) |> size
+(7, 13)
+
+julia> Dense(4 => 7, softmax)(x)
+ERROR: `softmax(x)` called with a number, but it expects an array.
+```
 """
-softmax(x; dims = 1) = softmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
+softmax(x::AbstractArray{T}; dims = 1) where {T} = softmax!(similar(x, float(T)), x; dims)

-softmax!(x; dims = 1) = softmax!(x, x; dims = dims)
+softmax!(x::AbstractArray; dims = 1) = softmax!(x, x; dims)

 function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
-    max_ = maximum(x; dims = dims)
+    max_ = maximum(x; dims)
     if all(isfinite, max_)
-        out .= exp.(x .- max_)
+        @fastmath out .= exp.(x .- max_)
     else
-        @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
+        @fastmath @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 1, 0), exp(x - max_))
     end
-    out ./= sum(out; dims = dims)  # could re-use max_ when dims != (:) and eltype(x) == T.
+    out ./= sum(out; dims)
 end

-∇softmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
-    ∇softmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
-∇softmax(Δ, x, y; dims = 1) = ∇softmax(unthunk(Δ), x, y, dims = dims)
-
-# Can introduce at the end of deprecation cycle of ∇softmax!(out, Δ, x; dims = 1)
-# ∇softmax!(Δ, x, y; dims = 1) = ∇softmax!(Δ, Δ, x, y; dims = dims)
-
-function ∇softmax!(out::AbstractArray, Δ::AbstractArray,
-                   x::AbstractArray, y::AbstractArray; dims = 1)
-    out .= Δ .* y
-    out .= out .- y .* sum(out; dims = dims)
+function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
+    dx = if within_grad()
+        tmp = dy .* y
+        tmp .- y .* sum(tmp; dims)
+    else
+        # This path is faster, only safe for 1st derivatives though.
+        # Was previously `∇softmax!(dx, dy, x, y; dims)` to allow CUDA overloads,
+        # but that was slow: https://github.com/FluxML/NNlibCUDA.jl/issues/30
+        out = similar(y, promote_type(T,S))
+        out .= dy .* y
+        out .= out .- y .* sum(out; dims)
+    end
 end

-# Old 2-arg version recomputing forward
-∇softmax(Δ, x; dims = 1) = ∇softmax(Δ, x, softmax(x, dims = dims); dims = dims)
-∇softmax!(Δ, x; dims = 1) = ∇softmax!(Δ, Δ, x, softmax(x, dims = dims); dims = dims)
-∇softmax!(out, Δ, x; dims = 1) = ∇softmax!(out, Δ, x, softmax(x, dims = dims); dims = dims)
-
-function rrule(::typeof(softmax), xs; dims=1)
-    y = softmax(xs; dims=dims)
-    softmax_pullback(Δ) = (NoTangent(), ∇softmax(unthunk(Δ), xs, y, dims = dims))
+function rrule(::typeof(softmax), x; dims = 1)
+    y = softmax(x; dims)
+    softmax_pullback(dy) = (NoTangent(), ∇softmax_data(unthunk(dy), y; dims))
     return y, softmax_pullback
 end

+within_grad() = false
+rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)
+
+
 """
     logsoftmax(x; dims = 1)

@@ -85,52 +104,52 @@ It is semantically equivalent to the following:

 See also [`softmax`](@ref).
 """
-logsoftmax(x; dims = 1) = logsoftmax!(similar(x, (float ∘ eltype)(x)), x; dims = dims)
+logsoftmax(x::AbstractArray{T}; dims = 1) where {T} = logsoftmax!(similar(x, float(T)), x; dims)

-logsoftmax!(x; dims = 1) = logsoftmax!(x, x; dims = dims)
+logsoftmax!(x::AbstractArray; dims = 1) = logsoftmax!(x, x; dims)

 function logsoftmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
-    max_ = maximum(x; dims = dims)
+    max_ = maximum(x; dims)
     if all(isfinite, max_)
         out .= x .- max_
     else
         @. out = ifelse(isequal(max_,Inf), ifelse(isequal(x,Inf), 0, -Inf), x - max_)
     end
-    log_ = log.(sum(exp, out; dims = dims))
+    @fastmath log_ = log.(sum(exp, out; dims))
     out .-= log_
 end

-∇logsoftmax(Δ::AbstractArray{T}, x::AbstractArray, y::AbstractArray{S}; dims = 1) where {T,S} =
-    ∇logsoftmax!(similar(y, promote_type(T, S)), Δ, x, y; dims = dims)
-∇logsoftmax(Δ, x, y; dims = 1) = ∇logsoftmax(unthunk(Δ), x, y, dims = dims)
-
-# Old 2-arg version recomputing forward
-∇logsoftmax(Δ, x; dims = 1) = ∇logsoftmax(Δ, x, logsoftmax(x, dims = dims); dims = dims)
-∇logsoftmax!(Δ, x; dims = 1) = ∇logsoftmax!(Δ, Δ, x, logsoftmax(x, dims = dims); dims = dims)
-∇logsoftmax!(out, Δ, x; dims = 1) = ∇logsoftmax!(out, Δ, x, logsoftmax(x, dims = dims); dims = dims)
-
-function ∇logsoftmax!(out::AbstractArray, Δ::AbstractArray,
-                      x::AbstractArray, y::AbstractArray; dims = 1)
-    out .= Δ .- sum(Δ, dims = dims) .* exp.(y)
+function ∇logsoftmax_data(dy::AbstractArray, y::AbstractArray; dims = 1)
+    # This was previously `∇logsoftmax!(dx, dy, x, y; dims)` to allow CUDA overloads, but that was slow.
+    dx = dy .- sum(dy; dims) .* exp.(y)
 end
-
-function rrule(::typeof(logsoftmax), xs; dims=1)
-    y = logsoftmax(xs; dims=dims)
-    logsoftmax_pullback(Δ) = (NoTangent(), ∇logsoftmax(unthunk(Δ), xs, y, dims = dims))
+
+function rrule(::typeof(logsoftmax), x; dims = 1)
+    y = logsoftmax(x; dims)
+    logsoftmax_pullback(dy) = (NoTangent(), ∇logsoftmax_data(unthunk(dy), y; dims))
     return y, logsoftmax_pullback
 end

 """
     logsumexp(x; dims = :)

-Computes `log.(sum(exp.(x); dims = dims))` in a numerically stable
-way.
+Computes `log.(sum(exp.(x); dims))` in a numerically stable way.
+Without `dims` keyword this returns a scalar.

 See also [`logsoftmax`](@ref).
 """
 function logsumexp(x::AbstractArray; dims = :)
-    max_ = maximum(x; dims = dims)
-    max_ .+ log.(sum(exp.(x .- max_); dims = dims))
+    max_ = maximum(x; dims)
+    @fastmath max_ .+ log.(sum(exp.(x .- max_); dims))
+end
+
+function rrule(::typeof(logsumexp), x; dims = :)
+    # The gradient is `softmax`, but both compute `tmp` so it's worth saving.
+    max_ = maximum(x; dims)
+    @fastmath tmp = exp.(x .- max_)
+    @fastmath y = max_ .+ log.(sum(tmp; dims))
+    logsumexp_pullback(dy) = (NoTangent(), unthunk(dy) .* tmp ./ sum(tmp; dims))
+    return y, logsumexp_pullback
 end

 # Informative error message if any of the softmax variants is called with a number
diff --git a/test/softmax.jl b/test/softmax.jl
index 453dee653..1d5105af6 100644
--- a/test/softmax.jl
+++ b/test/softmax.jl
@@ -1,4 +1,5 @@
 using Statistics: mean
+using NNlib: ∇softmax_data, ∇logsoftmax_data

 @testset "softmax integer input" begin
     @test softmax(Int[0, 0]) == [0.5, 0.5]
@@ -34,10 +35,10 @@ end
     @test logsoftmax(xs) ≈ [-999 -1998 -2997; 0 0 0.0]

     y = logsoftmax(xs)
-    @test ∇logsoftmax(ones(Float32, size(xs)), xs, y) ≈ Float32[1 1 1; -1 -1 -1]
+    @test ∇logsoftmax_data(ones(Float32, size(xs)), y) ≈ Float32[1 1 1; -1 -1 -1]

     y = softmax(xs)
-    @test ∇softmax(ones(Float32, size(xs)), xs, y) ≈ zeros(Float32, size(xs))
+    @test ∇softmax_data(ones(Float32, size(xs)), y) ≈ zeros(Float32, size(xs))

     # These values precalculated using PyTorch's nn.LogSoftmax
     xs = [
@@ -52,10 +53,10 @@
     ]

     y = logsoftmax(xs)
-    @test ∇logsoftmax(ones(size(xs)), xs, y) ≈ ys rtol = 1e-6
+    @test ∇logsoftmax_data(ones(size(xs)), y) ≈ ys rtol = 1e-6

     y = softmax(xs)
-    @test ∇softmax(ones(size(xs)), xs, y) ≈ zeros(size(xs)) atol = 1e-6
+    @test ∇softmax_data(ones(size(xs)), y) ≈ zeros(size(xs)) atol = 1e-6
 end

 @testset "softmax with Inf, NaN" begin
@@ -91,12 +92,12 @@
     @testset "$fn(Float64, $(size(xs)))" for fn in [zeros, ones, rand]
         Δ = fn(Float64, size(xs))
         y = softmax(xs)
-        ∇softmax!(out, Δ, xs, y)
-        @test out ≈ ∇softmax(Δ, xs, y) rtol = 1e-6
+        ∇softmax!(out, Δ, xs, y)  # deprecated
+        @test out ≈ ∇softmax_data(Δ, y) rtol = 1e-6

         y = logsoftmax(xs)
-        ∇logsoftmax!(out, Δ, xs, y)
-        @test out ≈ ∇logsoftmax(Δ, xs, y) rtol = 1e-6
+        ∇logsoftmax!(out, Δ, xs, y)  # deprecated
+        @test out ≈ ∇logsoftmax_data(Δ, y) rtol = 1e-6
     end
     end
 end
@@ -109,7 +110,6 @@ end
     @test logsumexp(x; dims = 1) ≈ flogsoft(x, dims = 1)
 end

-
 @testset "AutoDiff" begin
     for f in (softmax, logsoftmax), d in (:, 1, 2)
         gradtest(f, (3,4); fkwargs = (dims = d,), check_rrule = true)
@@ -117,6 +117,7 @@ end
     gradtest(x -> softmax(x) .* (1:3), 3)
     gradtest(x -> softmax(x) .* (1:3), (3,5), atol = 1e-4)
     gradtest(x -> softmax(x, dims = 2) .* (1:3), (3,5), atol = 1e-4)
+
     gradtest(x -> logsoftmax(x) .* (1:3), 3)
     gradtest(x -> logsoftmax(x) .* (1:3), (3,5))
     gradtest(x -> logsoftmax(x, dims = 2) .* (1:3), (3,5))
@@ -125,3 +126,15 @@ end
         gradtest(logsumexp, (3,4), fkwargs = (dims = d,))
     end
 end
+
+@testset "Second derivatives" begin
+    x = [1 2 3; 6 5 4]
+    H = Zygote.hessian_dual(x -> sum(sin, softmax(x)), x)
+    @test H ≈ Zygote.hessian_reverse(x -> sum(sin, softmax(x)), x)
+
+    H2 = Zygote.hessian_dual(x -> sum(sin, logsoftmax(x)), x)
+    @test H2 ≈ Zygote.hessian_reverse(x -> sum(sin, logsoftmax(x)), x)
+
+    H3 = Zygote.hessian_dual(x -> sum(sin, logsumexp(x)), x)
+    @test H3 ≈ Zygote.hessian_reverse(x -> sum(sin, logsumexp(x)), x)
+end
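
A minimal usage sketch of the API this diff introduces, assuming Zygote is loaded alongside this version of NNlib (NNlib itself only uses Zygote in its tests; the variable names below are illustrative). It checks `∇softmax_data(dy, y)` against the gradient Zygote computes through the new `rrule`, and then takes a second derivative, which is what the `within_grad()` switch above makes possible:

using NNlib, Zygote
using NNlib: ∇softmax_data

x  = randn(3, 5)              # softmax is applied along dims = 1
y  = softmax(x; dims = 1)
dy = randn(3, 5)              # an incoming cotangent

# The new helper computes the vector-Jacobian product from `y` alone:
dx = ∇softmax_data(dy, y; dims = 1)

# The same quantity via the rrule, as Zygote computes it:
dx_ad = Zygote.gradient(x -> sum(dy .* softmax(x; dims = 1)), x)[1]
dx ≈ dx_ad                    # true

# Inside AD, `within_grad()` is flipped to `true` by its rrule, so the
# non-mutating branch is taken and reverse-over-reverse differentiation works:
Zygote.hessian_reverse(v -> sum(sin, softmax(v)), x[:, 1])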