diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index acd279d709..dfa4a442f2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -75,4 +75,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
- DATADEPS_ALWAYS_ACCEPT: true
+
diff --git a/docs/make.jl b/docs/make.jl
index 4367639d8e..d74486936b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
DataFrames, JLD2, MLDataDevices
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md
index 7ad62ee207..7141fb0b28 100644
--- a/docs/src/guide/models/basics.md
+++ b/docs/src/guide/models/basics.md
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.
```
-Flux's [`gradient`](@ref) function by default calls a companion packages called [Zygote](https://github.com/FluxML/Zygote.jl).
+Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
hooks into Julia's compiler to find out what operations `f` contains, and transforms this
to produce code for computing `∂f/∂x`.
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
Its contents is stored in a tuple, thus `model3.layers[1].weight` is an array.
* Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
- Like `struct Poly3{T}` above, it has type parameters for its fields -- the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
- - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+ - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
- It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
- And it has some performance tricks: making sure element types match, and re-using some memory.
* The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md
index 5b2e70f095..446ff86f82 100644
--- a/docs/src/guide/models/recurrence.md
+++ b/docs/src/guide/models/recurrence.md
@@ -166,3 +166,50 @@ opt_state = Flux.setup(AdamW(1e-3), model)
g = gradient(m -> Flux.mse(m(x), y), model)[1]
Flux.update!(opt_state, model, g)
```
+
+Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell so that it processes the entire sequence at once. For instance, a layer behaving the same as the `LSTM` layer can be defined as follows:
+
+```julia
+rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3)
+x = rand(Float32, 2, 4, 3)
+y = rnn(x)
+```
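+
+The initial hidden state can also be passed explicitly, using the two-argument form described in the [`Recurrence`](@ref) docstring. A short sketch:
+
+```julia
+state = Flux.initialstates(rnn.cell)  # default initial state of the wrapped cell
+y = rnn(x, state)                     # equivalent to rnn(x)
+```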
+
+## Stacking recurrent layers
+
+Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of each layer is fed as input to the next layer in the chain.
+For instance, a model with two LSTM layers can be defined as follows:
+
+```julia
+stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
+x = rand(Float32, 3, 4)
+y = stacked_rnn(x)
+```
+
+If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:
+
+```julia
+struct StackedRNN{L,S}
+ layers::L
+ states0::S
+end
+
+Flux.@layer StackedRNN
+
+function StackedRNN(d::Int; num_layers::Int)
+ layers = [LSTM(d => d) for _ in 1:num_layers]
+ states0 = [Flux.initialstates(l) for l in layers]
+ return StackedRNN(layers, states0)
+end
+
+function (m::StackedRNN)(x)
+   for (layer, state0) in zip(m.layers, m.states0)
+ x = layer(x, state0)
+ end
+ return x
+end
+
+rnn = StackedRNN(3; num_layers=2)
+x = rand(Float32, 3, 10)
+y = rnn(x)
+```
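+
+Since `Flux.@layer` marks all of the struct's fields as trainable by default, both the `layers` and the initial states in `states0` receive gradients. A minimal training sketch, reusing the `AdamW` setup from the example above with a placeholder sum-of-squares loss:
+
+```julia
+opt_state = Flux.setup(AdamW(1e-3), rnn)
+g = gradient(m -> sum(abs2, m(x)), rnn)[1]  # gradients for layers and states0
+Flux.update!(opt_state, rnn, g)
+```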
diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md
index 355d3e7833..562304de70 100644
--- a/docs/src/reference/models/layers.md
+++ b/docs/src/reference/models/layers.md
@@ -104,6 +104,7 @@ PairwiseFusion
Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
```@docs
+Recurrence
RNNCell
RNN
LSTMCell
diff --git a/src/Flux.jl b/src/Flux.jl
index 8fb2351aa2..3a598e88f5 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
export Chain, Dense, Embedding, EmbeddingBag,
Maxout, SkipConnection, Parallel, PairwiseFusion,
RNNCell, LSTMCell, GRUCell, GRUv3Cell,
- RNN, LSTM, GRU, GRUv3,
+ RNN, LSTM, GRU, GRUv3, Recurrence,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout,
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9386a3fc2d..a6a82b0e72 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,8 +1,7 @@
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]
-function scan(cell, x, state0)
- state = state0
+function scan(cell, x, state)
y = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
@@ -12,7 +11,67 @@ function scan(cell, x, state0)
return stack(y, dims = 2)
end
+"""
+ Recurrence(cell)
+
+Create a recurrent layer that processes an entire sequence using the given
+recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
+similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
+
+The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
+a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
+the initial hidden state. The output of the `cell` is considered to be:
+1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
+2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
+
+# Forward
+
+ rnn(x, [state])
+
+The input `x` should be an array of size `in x len` or `in x len x batch_size`,
+where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
+The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
+`Flux.initialstates(cell)`.
+
+The output is an array of size `out x len` or `out x len x batch_size`, where `out` is the output dimension of the cell.
+
+The operation performed is semantically equivalent to the following code:
+```julia
+out_from_state(state) = state
+out_from_state(state::Tuple) = state[1]
+state = Flux.initialstates(cell)
+out = []
+for x_t in eachslice(x, dims = 2)
+ state = cell(x_t, state)
+ out = [out..., out_from_state(state)]
+end
+stack(out, dims = 2)
+```
+
+# Examples
+
+```jldoctest
+julia> rnn = Recurrence(RNNCell(2 => 3))
+Recurrence(
+ RNNCell(2 => 3, tanh), # 18 parameters
+) # Total: 3 arrays, 18 parameters, 232 bytes.
+
+julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
+
+julia> y = rnn(x); # out x len x batch_size
+```
+"""
+struct Recurrence{M}
+ cell::M
+end
+
+@layer Recurrence
+
+initialstates(rnn::Recurrence) = initialstates(rnn.cell)
+
+(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
# Vanilla RNN
@doc raw"""
@@ -184,9 +243,7 @@ julia> x = rand(Float32, (d_in, len, batch_size));
julia> h = zeros(Float32, (d_out, batch_size));
julia> rnn = RNN(d_in => d_out)
-RNN(
- RNNCell(4 => 6, tanh), # 66 parameters
-) # Total: 3 arrays, 66 parameters, 424 bytes.
+RNN(4 => 6, tanh) # 66 parameters
julia> y = rnn(x, h); # [y] = [d_out, len, batch_size]
```
@@ -211,7 +268,7 @@ struct RNN{M}
cell::M
end
-@layer RNN
+@layer :noexpand RNN
initialstates(rnn::RNN) = initialstates(rnn.cell)
@@ -232,6 +289,12 @@ function (m::RNN)(x::AbstractArray, h)
return scan(m.cell, x, h)
end
+function Base.show(io::IO, m::RNN)
+ print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
+ print(io, ", ", m.cell.σ)
+ print(io, ")")
+end
+
# LSTM
@doc raw"""
@@ -400,7 +463,7 @@ struct LSTM{M}
cell::M
end
-@layer LSTM
+@layer :noexpand LSTM
initialstates(lstm::LSTM) = initialstates(lstm.cell)
@@ -416,6 +479,10 @@ function (m::LSTM)(x::AbstractArray, state0)
return scan(m.cell, x, state0)
end
+function Base.show(io::IO, m::LSTM)
+ print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
+end
+
# GRU
@doc raw"""
@@ -568,7 +635,7 @@ struct GRU{M}
cell::M
end
-@layer GRU
+@layer :noexpand GRU
initialstates(gru::GRU) = initialstates(gru.cell)
@@ -584,6 +651,10 @@ function (m::GRU)(x::AbstractArray, h)
return scan(m.cell, x, h)
end
+function Base.show(io::IO, m::GRU)
+ print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
+
# GRU v3
@doc raw"""
GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -728,7 +799,7 @@ struct GRUv3{M}
cell::M
end
-@layer GRUv3
+@layer :noexpand GRUv3
initialstates(gru::GRUv3) = initialstates(gru.cell)
@@ -743,3 +814,7 @@ function (m::GRUv3)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(m.cell, x, h)
end
+
+function Base.show(io::IO, m::GRUv3)
+ print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
\ No newline at end of file
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 864e5dad8e..3ad7428601 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))
end
+
+@testset "Recurrence" begin
+ x = rand(Float32, 2, 3, 4)
+ for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
+ cell = rnn.cell
+ rec = Recurrence(cell)
+ @test rec(x) ≈ rnn(x)
+ end
+end
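+
+# Additional sketch: the two-argument form `rec(x, state)` should agree with the
+# wrapped recurrent layer when both are given the same initial state.
+@testset "Recurrence with explicit state" begin
+  x = rand(Float32, 2, 3, 4)
+  for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
+    cell = rnn.cell
+    rec = Recurrence(cell)
+    state = Flux.initialstates(cell)
+    @test rec(x, state) ≈ rnn(x, state)
+  end
+end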