diff --git a/NEWS.md b/NEWS.md
index 2a40a64aec..a3558a99a3 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -2,10 +2,17 @@
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
 
-## v0.15.3
-* Add `WeightNorm` normalization layer.
+## v0.16.0 (15 December 2024)
+This release has a single **breaking change**:
 
-## v0.15.0 (December 2024)
+- The forward pass of the recurrent cells `RNNCell`, `LSTMCell`, and `GRUCell` has been changed to
+  $y_t, state_t = cell(x_t, state_{t-1})$. Previously, it was $state_t = cell(x_t, state_{t-1})$.
+
+Other highlights include:
+* Added `WeightNorm` normalization layer.
+* Added `Recurrence` layer, turning a recurrent cell into a layer that processes an entire sequence at once.
+
+## v0.15.0 (5 December 2024)
 This release includes two **breaking changes**:
 - The recurrent layers have been thoroughly revised. See below and read the [documentation](https://fluxml.ai/Flux.jl/v0.15/guide/models/recurrence/) for details.
 - Flux now defines and exports its own gradient function. Consequently, using gradient in an unqualified manner (e.g., after `using Flux, Zygote`) could result in an ambiguity error.
diff --git a/Project.toml b/Project.toml
index 25c38c4a8f..3e5224889d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Flux"
 uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.15.2"
+version = "0.16.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md
index 446ff86f82..44834e9a8a 100644
--- a/docs/src/guide/models/recurrence.md
+++ b/docs/src/guide/models/recurrence.md
@@ -21,7 +21,7 @@ b = zeros(Float32, output_size)
 
 function rnn_cell(x, h)
     h = tanh.(Wxh * x .+ Whh * h .+ b)
-    return h
+    return h, h
 end
 
 seq_len = 3
@@ -33,14 +33,14 @@ h0 = zeros(Float32, output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]] # concatenate in non-mutating (AD friendly) way
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]] # concatenate in non-mutating (AD friendly) way
 end
 ```
 
 Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.
-
-The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
+The result of the forward pass at each time step is a tuple containing the output `yt` and the updated state `ht`. The updated state is used as an input in the next iteration. In the simple case of a vanilla RNN, the
+output and the state are the same. In more complex cells, such as `LSTMCell`, the state can contain multiple arrays.
 
 There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with:
@@ -58,8 +58,8 @@ rnn_cell = Flux.RNNCell(input_size => output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]]
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]]
 end
 ```
 The entire output `y` or just the last output `y[end]` can be used for further processing, such as classification or regression.
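To make the breaking change described in the NEWS entry above concrete, here is a minimal migration sketch (the layer sizes and variable names are illustrative and not taken from this diff):

```julia
using Flux

cell = Flux.RNNCell(3 => 5)       # illustrative sizes
x = rand(Float32, 3)              # a single time step of input
state = Flux.initialstates(cell)  # zero initial state

# Flux v0.15: the cell returned only the updated state
# state = cell(x, state)

# Flux v0.16: the cell returns a tuple (output, state)
y, state = cell(x, state)
y === state  # true for RNNCell and GRUCell; LSTMCell returns (h, (h, c))
```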
@@ -78,7 +78,7 @@ struct RecurrentCellModel{H,C,D}
 end
 
 # we choose to not train the initial hidden state
-Flux.@layer RecurrentCellModel trainable=(cell,dense)
+Flux.@layer RecurrentCellModel trainable=(cell, dense)
 
 function RecurrentCellModel(input_size::Int, hidden_size::Int)
     return RecurrentCellModel(
@@ -91,8 +91,8 @@ function (m::RecurrentCellModel)(x)
     z = []
     ht = m.h0
     for xt in x
-        ht = m.cell(xt, ht)
-        z = [z; [ht]]
+        yt, ht = m.cell(xt, ht)
+        z = [z; [yt]]
     end
     z = stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
     ŷ = m.dense(z) # [1, seq_len, batch_size] or [1, seq_len]
@@ -109,7 +109,6 @@ using Optimisers: AdamW
 
 function loss(model, x, y)
     ŷ = model(x)
-    y = stack(y, dims=2)
     return Flux.mse(ŷ, y)
 end
 
@@ -123,7 +122,7 @@ model = RecurrentCellModel(input_size, 5)
 opt_state = Flux.setup(AdamW(1e-3), model)
 
 # compute the gradient and update the model
-g = gradient(m -> loss(m, x, y),model)[1]
+g = gradient(m -> loss(m, x, y), model)[1]
 Flux.update!(opt_state, model, g)
 ```
 
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 596528b1e7..db1384634f 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,12 +1,8 @@
-out_from_state(state) = state
-out_from_state(state::Tuple) = state[1]
-
 function scan(cell, x, state)
     y = []
     for x_t in eachslice(x, dims = 2)
-        state = cell(x_t, state)
-        out = out_from_state(state)
-        y = vcat(y, [out])
+        yt, state = cell(x_t, state)
+        y = vcat(y, [yt])
     end
     return stack(y, dims = 2)
 end
@@ -85,7 +81,6 @@ In the forward pass, implements the function
 ```math
 h^\prime = \sigma(W_i x + W_h h + b)
 ```
-and returns `h'`.
 
 See [`RNN`](@ref) for a layer that processes entire sequences.
 
@@ -107,6 +102,9 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
+Returns a tuple `(output, state)`, where both elements are given by the updated state `h'`,
+a tensor of size `out` or `out x batch_size`.
+
 # Examples
 
 ```julia
@@ -123,10 +121,10 @@ h = zeros(Float32, 5)
 ŷ = []
 for x_t in x
-    h = r(x_t, h)
-    ŷ = [ŷ..., h] # Cannot use `push!(ŷ, h)` here since mutation
-                  # is not automatic differentiation friendly yet.
-                  # Can use `y = vcat(y, [h])` as an alternative.
+    yt, h = r(x_t, h)
+    ŷ = [ŷ..., yt] # Cannot use `push!(ŷ, yt)` here since mutation
+                   # is not automatic differentiation friendly yet.
+                   # Can use `ŷ = vcat(ŷ, [yt])` as an alternative.
 end
 
 h # The final hidden state
@@ -155,40 +153,37 @@ using Flux
 rnn = RNNCell(10 => 20)
 
 # Get the initial hidden state
-h0 = initialstates(rnn)
+state = initialstates(rnn)
 
 # Get some input data
 x = rand(Float32, 10)
 
 # Run forward
-res = rnn(x, h0)
+out, state = rnn(x, state)
 ```
 """
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
 
 function RNNCell(
-    (in, out)::Pair,
-    σ = tanh;
-    init_kernel = glorot_uniform,
-    init_recurrent_kernel = glorot_uniform,
-    bias = true,
-)
+        (in, out)::Pair,
+        σ = tanh;
+        init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform,
+        bias = true,
+    )
     Wi = init_kernel(out, in)
     Wh = init_recurrent_kernel(out, out)
     b = create_bias(Wi, bias, size(Wi, 1))
     return RNNCell(σ, Wi, Wh, b)
 end
 
-function (rnn::RNNCell)(x::AbstractVecOrMat)
-    state = initialstates(rnn)
-    return rnn(x, state)
-end
+(rnn::RNNCell)(x::AbstractVecOrMat) = rnn(x, initialstates(rnn))
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi, 2))
     σ = NNlib.fast_act(m.σ, x)
     h = σ.(m.Wi * x .+ m.Wh * h .+ m.bias)
-    return h
+    return h, h
 end
 
 function Base.show(io::IO, m::RNNCell)
@@ -278,10 +273,7 @@ function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
     return RNN(cell)
 end
 
-function (rnn::RNN)(x::AbstractArray)
-    state = initialstates(rnn)
-    return rnn(x, state)
-end
+(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))
 
 function (m::RNN)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
@@ -315,7 +307,6 @@ o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
 h_t = o_t \odot \tanh(c_t)
 ```
 
-The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step.
 See also [`LSTM`](@ref) for a layer that processes entire sequences.
 
 # Arguments
@@ -336,7 +327,8 @@ The arguments of the forward pass are:
   They should be vectors of size `out` or matrices of size `out x batch_size`.
   If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
 
-Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
+Returns a tuple `(output, state)`, where `output = h'` is the new hidden state and `state = (h', c')` contains the new hidden and cell states.
+These are tensors of size `out` or `out x batch_size`.
 
 # Examples
 
@@ -350,9 +342,9 @@ julia> c = zeros(Float32, 5); # cell state
 
 julia> x = rand(Float32, 3, 4); # in x batch_size
 
-julia> h′, c′ = l(x, (h, c));
+julia> y, (h′, c′) = l(x, (h, c));
 
-julia> size(h′) # out x batch_size
+julia> size(y) # out x batch_size
 (5, 4)
 ```
 """
@@ -389,9 +381,9 @@ function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
     b = m.bias
     g = m.Wi * x .+ m.Wh * h .+ b
     input, forget, cell, output = chunk(g, 4; dims = 1)
-    c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
-    h′ = @. sigmoid_fast(output) * tanh_fast(c′)
-    return h′, c′
+    c = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
+    h = @. sigmoid_fast(output) * tanh_fast(c)
+    return h, (h, c)
 end
 
 Base.show(io::IO, m::LSTMCell) =
@@ -522,7 +514,8 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
-Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
+Returns the tuple `(output, state)`, where `output = h'` and `state = h'`.
+The new hidden state `h'` is an array of size `out` or `out x batch_size`. # Examples @@ -534,7 +527,7 @@ julia> h = zeros(Float32, 5); # hidden state julia> x = rand(Float32, 3, 4); # in x batch_size -julia> h′ = g(x, h); +julia> y, h = g(x, h); ``` """ struct GRUCell{I, H, V} @@ -577,8 +570,8 @@ function (m::GRUCell)(x::AbstractVecOrMat, h) r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1]) z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2]) h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3]) - h′ = @. (1 - z) * h̃ + z * h - return h′ + h = @. (1 - z) * h̃ + z * h + return h, h end Base.show(io::IO, m::GRUCell) = @@ -693,7 +686,8 @@ The arguments of the forward pass are: - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). -Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`. +Returns the tuple `(output, state)`, where `output = h'` and `state = h'`. +The new hidden state `h'` is an array of size `out` or `out x batch_size`. """ struct GRUv3Cell{I, H, V, HH} Wi::I @@ -736,8 +730,8 @@ function (m::GRUv3Cell)(x::AbstractVecOrMat, h) r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1]) z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2]) h̃ = tanh_fast.(gxs[3] .+ (m.Wh_h̃ * (r .* h)) .+ bs[3]) - h′ = @. (1 - z) * h̃ + z * h - return h′ + h = @. (1 - z) * h̃ + z * h + return h, h end Base.show(io::IO, m::GRUv3Cell) = diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl index 3046f1f1fa..35fa983806 100644 --- a/test/ext_common/recurrent_gpu_ad.jl +++ b/test/ext_common/recurrent_gpu_ad.jl @@ -1,11 +1,9 @@ -out_from_state(state::Tuple) = state[1] -out_from_state(state) = state +cell_loss(cell, x, state) = mean(cell(x, state)[1]) function recurrent_cell_loss(cell, seq, state) out = [] for xt in seq - state = cell(xt, state) - yt = out_from_state(state) + yt, state = cell(xt, state) out = vcat(out, [yt]) end return mean(stack(out, dims = 2)) @@ -18,7 +16,8 @@ end h = zeros(Float32, d_out) # Single Step @test test_gradients(r, x[1], h; test_gpu=true, - compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS + compare_finite_diff=false, + loss=cell_loss) broken = :rnncell_single ∈ BROKEN_TESTS # Multiple Steps @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, @@ -51,7 +50,7 @@ end c = zeros(Float32, d_out) # Single Step @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false, - loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS + loss = cell_loss) broken = :lstmcell_single ∈ BROKEN_TESTS # Multiple Steps @test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff = false, @@ -84,7 +83,9 @@ end r = GRUCell(d_in => d_out) x = [randn(Float32, d_in, batch_size) for _ in 1:len] h = zeros(Float32, d_out) - @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS + @test test_gradients(r, x[1], h; test_gpu=true, + compare_finite_diff=false, + loss = cell_loss) broken = :grucell_single ∈ BROKEN_TESTS @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff = false, loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS @@ -116,7 +117,8 @@ end x = [randn(Float32, d_in, batch_size) for _ in 1:len] h = zeros(Float32, d_out) @test test_gradients(r, x[1], h; test_gpu=true, - compare_finite_diff=false) broken = :gruv3cell_single ∈ 
BROKEN_TESTS + compare_finite_diff=false, + loss=cell_loss) broken = :gruv3cell_single ∈ BROKEN_TESTS @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 3ad7428601..73dae4ac65 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -1,36 +1,37 @@ - -@testset "RNNCell" begin - function loss1(r, x, h) - for x_t in x - h = r(x_t, h) - end - return mean(h.^2) +function cell_loss1(r, x, state) + for x_t in x + _, state = r(x_t, state) end + return mean(state[1]) +end - function loss2(r, x, h) - y = [r(x_t, h) for x_t in x] - return sum(mean, y) - end +function cell_loss2(r, x, state) + y = [r(x_t, state)[1] for x_t in x] + return sum(mean, y) +end - function loss3(r, x, h) - y = [] - for x_t in x - h = r(x_t, h) - y = [y..., h] - end - return sum(mean, y) +function cell_loss3(r, x, state) + y = [] + for x_t in x + y_t, state = r(x_t, state) + y = [y..., y_t] end + return sum(mean, y) +end - function loss4(r, x, h) - y = [] - for x_t in x - h = r(x_t, h) - y = vcat(y, [h]) - end - y = stack(y, dims=2) # [D, L] or [D, L, B] - return mean(y.^2) +function cell_loss4(r, x, state) + y = [] + for x_t in x + y_t, state = r(x_t, state) + y = vcat(y, [y_t]) end + y = stack(y, dims=2) # [D, L] or [D, L, B] + return mean(y.^2) +end + +@testset "RNNCell" begin + r = RNNCell(3 => 5) @test length(Flux.trainables(r)) == 3 # An input sequence of length 6 and batch size 4. @@ -38,31 +39,33 @@ # Initial State is a single vector h = randn(Float32, 5) - test_gradients(r, x, h, loss=loss1) # for loop - test_gradients(r, x, h, loss=loss2) # comprehension - test_gradients(r, x, h, loss=loss3) # splat - test_gradients(r, x, h, loss=loss4) # vcat and stack + test_gradients(r, x, h, loss=cell_loss1) # for loop + test_gradients(r, x, h, loss=cell_loss2) # comprehension + test_gradients(r, x, h, loss=cell_loss3) # splat + test_gradients(r, x, h, loss=cell_loss4) # vcat and stack # initial states are zero @test Flux.initialstates(r) ≈ zeros(Float32, 5) # no initial state same as zero initial state - @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) + out, state = r(x[1]) + @test out === state + @test out ≈ r(x[1], zeros(Float32, 5))[1] # Now initial state has a batch dimension. h = randn(Float32, 5, 4) - test_gradients(r, x, h, loss=loss4) + test_gradients(r, x, h, loss=cell_loss4) # The input sequence has no batch dimension. 
x = [rand(Float32, 3) for _ in 1:6] h = randn(Float32, 5) - test_gradients(r, x, h, loss=loss4) + test_gradients(r, x, h, loss=cell_loss4) # No Bias r = RNNCell(3 => 5, bias=false) @test length(Flux.trainables(r)) == 2 - test_gradients(r, x, h, loss=loss4) + test_gradients(r, x, h, loss=cell_loss4) end @testset "RNN" begin @@ -99,32 +102,20 @@ end @testset "LSTMCell" begin - function loss(r, x, hc) - h, c = hc - h′ = [] - c′ = [] - for x_t in x - h, c = r(x_t, (h, c)) - h′ = vcat(h′, [h]) - c′ = [c′..., c] - end - hnew = stack(h′, dims=2) - cnew = stack(c′, dims=2) - return mean(hnew.^2) + mean(cnew.^2) - end - cell = LSTMCell(3 => 5) @test length(Flux.trainables(cell)) == 3 x = [rand(Float32, 3, 4) for _ in 1:6] h = zeros(Float32, 5, 4) c = zeros(Float32, 5, 4) - hnew, cnew = cell(x[1], (h, c)) + out, state = cell(x[1], (h, c)) + hnew, cnew = state + @test out === hnew @test hnew isa Matrix{Float32} @test cnew isa Matrix{Float32} @test size(hnew) == (5, 4) @test size(cnew) == (5, 4) test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1])) - test_gradients(cell, x, (h, c), loss = loss) + test_gradients(cell, x, (h, c), loss = cell_loss4) # initial states are zero h0, c0 = Flux.initialstates(cell) @@ -132,8 +123,8 @@ end @test c0 ≈ zeros(Float32, 5) # no initial state same as zero initial state - hnew1, cnew1 = cell(x[1]) - hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5))) + _, (hnew1, cnew1) = cell(x[1]) + _, (hnew2, cnew2) = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5))) @test hnew1 ≈ hnew2 @test cnew1 ≈ cnew2 @@ -184,16 +175,6 @@ end end @testset "GRUCell" begin - function loss(r, x, h) - y = [] - for x_t in x - h = r(x_t, h) - y = vcat(y, [h]) - end - y = stack(y, dims=2) # [D, L] or [D, L, B] - return mean(y.^2) - end - r = GRUCell(3 => 5) @test length(Flux.trainables(r)) == 3 # An input sequence of length 6 and batch size 4. @@ -201,22 +182,24 @@ end # Initial State is a single vector h = randn(Float32, 5) - test_gradients(r, x, h; loss) + out, state = r(x[1], h) + @test out === state + test_gradients(r, x, h; loss = cell_loss4) # initial states are zero @test Flux.initialstates(r) ≈ zeros(Float32, 5) # no initial state same as zero initial state - @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) + @test r(x[1])[1] ≈ r(x[1], zeros(Float32, 5))[1] # Now initial state has a batch dimension. h = randn(Float32, 5, 4) - test_gradients(r, x, h; loss) + test_gradients(r, x, h; loss = cell_loss4) # The input sequence has no batch dimension. x = [rand(Float32, 3) for _ in 1:6] h = randn(Float32, 5) - test_gradients(r, x, h; loss) + test_gradients(r, x, h; loss = cell_loss4) # No Bias r = GRUCell(3 => 5, bias=false) @@ -262,22 +245,24 @@ end # Initial State is a single vector h = randn(Float32, 5) - test_gradients(r, x, h) + out, state = r(x, h) + @test out === state + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) # initial states are zero @test Flux.initialstates(r) ≈ zeros(Float32, 5) # no initial state same as zero initial state - @test r(x) ≈ r(x, zeros(Float32, 5)) + @test r(x)[1] ≈ r(x, zeros(Float32, 5))[1] # Now initial state has a batch dimension. h = randn(Float32, 5, 4) - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) # The input sequence has no batch dimension. x = rand(Float32, 3) h = randn(Float32, 5) - test_gradients(r, x, h) + test_gradients(r, x, h, loss = (m, x, h) -> mean(m(x, h)[1])) end @testset "GRUv3" begin