Recurrence layer (#2549)
CarloLucibello authored Dec 13, 2024
1 parent 009d35b commit 6041cf5
Showing 8 changed files with 146 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -75,4 +75,4 @@ jobs:

1 change: 1 addition & 0 deletions docs/make.jl
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
DataFrames, JLD2, MLDataDevices


DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

4 changes: 2 additions & 2 deletions docs/src/guide/models/
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.
<h3><img src="../../../assets/zygote-crop.png" width="40px"/>&nbsp;<a href="">Zygote.jl</a></h3>

-Flux's [`gradient`](@ref) function by default calls a companion package called [Zygote](
+Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](
Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
hooks into Julia's compiler to find out what operations `f` contains, and transforms this
to produce code for computing `∂f/∂x`.
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
Its contents are stored in a tuple, thus `model3.layers[1].weight` is an array.
* Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
- Like `struct Poly3{T}` above, it has type parameters for its fields -- the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
-  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
- It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
- And it has some performance tricks: making sure element types match, and re-using some memory.
* The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
47 changes: 47 additions & 0 deletions docs/src/guide/models/
@@ -166,3 +166,50 @@ opt_state = Flux.setup(AdamW(1e-3), model)
g = gradient(m -> Flux.mse(m(x), y), model)[1]
Flux.update!(opt_state, model, g)

Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell so that it processes an entire sequence at once. For instance, a layer behaving the same as the `LSTM` layer can be defined as follows:

rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3)
x = rand(Float32, 2, 4, 3)
y = rnn(x)

## Stacking recurrent layers

Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
For instance, a model with two LSTM layers can be defined as follows:

stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
x = rand(Float32, 3, 4)
y = stacked_rnn(x)

If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:

struct StackedRNN{L,S}
    layers::L
    states0::S
end

Flux.@layer StackedRNN

function StackedRNN(d::Int; num_layers::Int)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
    # Read the layers and initial states from `m` itself, not from a global `rnn`.
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end

rnn = StackedRNN(3; num_layers=2)
x = rand(Float32, 3, 10)
y = rnn(x)
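
The custom model above threads one initial state per layer through the stack. The following is a minimal, Flux-free sketch of that pattern, written so the result can be checked by hand. `ToyLayer` and `ToyStack` are hypothetical names introduced only for illustration; `ToyLayer` is a running sum over time and stands in for a real recurrent layer.

```julia
# Hypothetical stand-in for a sequence layer (not a Flux type):
# a running sum over the time dimension, scaled by `a`.
struct ToyLayer
    a::Float32
end

# `x` has size (features, len); `h0` is a vector of length `features`.
function (l::ToyLayer)(x::AbstractMatrix, h0::AbstractVector)
    h = h0
    ys = similar(x)
    for t in axes(x, 2)
        h = l.a .* h .+ x[:, t]   # state update for time step t
        ys[:, t] = h              # the output at step t is the new state
    end
    return ys
end

# Same shape as the StackedRNN idea: one stored initial state per layer.
struct ToyStack{L,S}
    layers::L
    states0::S
end

function (m::ToyStack)(x)
    # The forward pass reads everything from `m`, never from a global.
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end

layers = [ToyLayer(1f0), ToyLayer(1f0)]
states0 = [zeros(Float32, 2) for _ in layers]
model = ToyStack(layers, states0)

y = model(ones(Float32, 2, 3))
# Two stacked running sums of a sequence of ones.
@assert y == Float32[1 3 6; 1 3 6]
```

Because the forward pass only touches `m.layers` and `m.states0`, the model behaves the same whatever variable name it is bound to.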
1 change: 1 addition & 0 deletions docs/src/reference/models/
@@ -104,6 +104,7 @@ PairwiseFusion
Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).

2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
export Chain, Dense, Embedding, EmbeddingBag,
Maxout, SkipConnection, Parallel, PairwiseFusion,
RNNCell, LSTMCell, GRUCell, GRUv3Cell,
RNN, LSTM, GRU, GRUv3, Recurrence,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout,
93 changes: 84 additions & 9 deletions src/layers/recurrent.jl
@@ -1,8 +1,7 @@
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

-function scan(cell, x, state0)
-  state = state0
+function scan(cell, x, state)
y = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
@@ -12,7 +11,67 @@ function scan(cell, x, state0)
return stack(y, dims = 2)

Create a recurrent layer that processes entire sequences out
of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
the initial hidden state. The output of the `cell` is considered to be:
1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
# Forward
rnn(x, [state])
The input `x` should be an array of size `in x len` or `in x len x batch_size`,
where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling `Flux.initialstates(cell)`.
The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.
The operation performed is semantically equivalent to the following code:
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]
state = Flux.initialstates(cell)
out = []
for x_t in eachslice(x, dims = 2)
  state = cell(x_t, state)
  out = [out..., out_from_state(state)]
end
stack(out, dims = 2)
# Examples
julia> rnn = Recurrence(RNNCell(2 => 3))
Recurrence(
  RNNCell(2 => 3, tanh),                # 18 parameters
)                   # Total: 3 arrays, 18 parameters, 232 bytes.
julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
julia> y = rnn(x); # out x len x batch_size
struct Recurrence{M}
  cell::M
end

@layer Recurrence

initialstates(rnn::Recurrence) = initialstates(rnn.cell)

(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
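
The `Recurrence` forward pass above reduces to a scan over the time dimension. Below is a minimal, Flux-free sketch of that pattern. `ToyCell` and `ToyRecurrence` are hypothetical stand-ins introduced only for illustration (they are not Flux types), and the sketch assumes the convention used in this commit, where calling the cell returns the new state.

```julia
# Hypothetical toy cell (not a Flux type): h' = tanh.(W*x .+ U*h).
struct ToyCell
    W::Matrix{Float32}   # out x in
    U::Matrix{Float32}   # out x out
end

# The wrapper asks the cell for its initial state, as Recurrence does.
initialstates(c::ToyCell) = zeros(Float32, size(c.U, 1))

# Calling the cell returns the new state.
(c::ToyCell)(x, h) = tanh.(c.W * x .+ c.U * h)

# The output is the state itself, or its first element for tuple states.
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

function scan(cell, x, state)
    ys = []
    for x_t in eachslice(x, dims = 2)   # iterate over time steps
        state = cell(x_t, state)
        push!(ys, out_from_state(state))
    end
    return stack(ys, dims = 2)          # time becomes dimension 2 again
end

struct ToyRecurrence{M}
    cell::M
end

(r::ToyRecurrence)(x::AbstractArray) = r(x, initialstates(r.cell))
(r::ToyRecurrence)(x::AbstractArray, state) = scan(r.cell, x, state)

cell = ToyCell(randn(Float32, 3, 2), randn(Float32, 3, 3))
rec = ToyRecurrence(cell)

x = rand(Float32, 2, 5, 4)   # in x len x batch_size
y = rec(x)                   # out x len x batch_size
@assert size(y) == (3, 5, 4)

# With a zero initial state, the first output depends on the input only.
@assert y[:, 1, :] ≈ tanh.(cell.W * x[:, 1, :])
```

The same wrapper accepts unbatched input of size `in x len`, in which case each time slice is a vector and the result has size `out x len`. `stack` requires Julia 1.9 or later.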

# Vanilla RNN
@doc raw"""
@@ -185,9 +244,7 @@ julia> x = rand(Float32, (d_in, len, batch_size));
julia> h = zeros(Float32, (d_out, batch_size));
julia> rnn = RNN(d_in => d_out)
-RNN(
-  RNNCell(4 => 6, tanh),                # 66 parameters
-)                   # Total: 3 arrays, 66 parameters, 424 bytes.
+RNN(4 => 6, tanh)   # 66 parameters
julia> y = rnn(x, h); # [y] = [d_out, len, batch_size]
@@ -212,7 +269,7 @@ struct RNN{M}

-@layer RNN
+@layer :noexpand RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

@@ -233,6 +290,12 @@ function (m::RNN)(x::AbstractArray, h)
return scan(m.cell, x, h)

function Base.show(io::IO, m::RNN)
  print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
  print(io, ", ", m.cell.σ)
  print(io, ")")
end

@doc raw"""
@@ -401,7 +464,7 @@ struct LSTM{M}

-@layer LSTM
+@layer :noexpand LSTM

initialstates(lstm::LSTM) = initialstates(lstm.cell)

@@ -417,6 +480,10 @@ function (m::LSTM)(x::AbstractArray, state0)
return scan(m.cell, x, state0)

function Base.show(io::IO, m::LSTM)
  print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
end
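
The compact `show` methods recover `in => out` from the shape of the input weight matrix. For the LSTM, the row count is divided by 4 because the four gates are stacked along the rows of `Wi`. A minimal sketch of the same idea follows, using hypothetical `ToyLSTMCell` and `ToyLSTM` types that are not part of Flux.

```julia
# Hypothetical stand-ins (not Flux types).
struct ToyLSTMCell
    Wi::Matrix{Float32}   # (4 * out) x in: four gates stacked along the rows
end

struct ToyLSTM
    cell::ToyLSTMCell
end

# Compact one-line printing: recover `in => out` from the shape of `Wi`.
function Base.show(io::IO, m::ToyLSTM)
    print(io, "ToyLSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
end

m = ToyLSTM(ToyLSTMCell(zeros(Float32, 12, 2)))   # in = 2, out = 3
@assert sprint(show, m) == "ToyLSTM(2 => 3)"
```

The GRU variants in this diff follow the same pattern with a divisor of 3.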


@doc raw"""
@@ -569,7 +636,7 @@ struct GRU{M}

-@layer GRU
+@layer :noexpand GRU

initialstates(gru::GRU) = initialstates(gru.cell)

@@ -585,6 +652,10 @@ function (m::GRU)(x::AbstractArray, h)
return scan(m.cell, x, h)

function Base.show(io::IO, m::GRU)
  print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end

# GRU v3
@doc raw"""
GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -729,7 +800,7 @@ struct GRUv3{M}

-@layer GRUv3
+@layer :noexpand GRUv3

initialstates(gru::GRUv3) = initialstates(gru.cell)

@@ -744,3 +815,7 @@ function (m::GRUv3)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(m.cell, x, h)

function Base.show(io::IO, m::GRUv3)
  print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end
9 changes: 9 additions & 0 deletions test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))

@testset "Recurrence" begin
    x = rand(Float32, 2, 3, 4)
    for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
        cell = rnn.cell
        rec = Recurrence(cell)
        @test rec(x) ≈ rnn(x)
    end
end
