Recurrence layer (#2549)
CarloLucibello authored Dec 13, 2024
1 parent 009d35b commit 6041cf5
Showing 8 changed files with 146 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -75,4 +75,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
DATADEPS_ALWAYS_ACCEPT: true

1 change: 1 addition & 0 deletions docs/make.jl
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
DataFrames, JLD2, MLDataDevices

ENV["DATADEPS_ALWAYS_ACCEPT"] = true

DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

4 changes: 2 additions & 2 deletions docs/src/guide/models/basics.md
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.
<h3><img src="../../../assets/zygote-crop.png" width="40px"/>&nbsp;<a href="https://github.com/FluxML/Zygote.jl">Zygote.jl</a></h3>
```

- Flux's [`gradient`](@ref) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
+ Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
hooks into Julia's compiler to find out what operations `f` contains, and transforms this
to produce code for computing `∂f/∂x`.
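As a quick illustration (not part of this diff), `gradient` returns a tuple with one entry per argument of `f`:

```julia
using Flux  # Flux re-exports Zygote's gradient

f(x) = 3x^2 + 2x + 1
df = gradient(f, 2.0)  # (14.0,) since ∂f/∂x = 6x + 2 at x = 2
```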
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
Its contents are stored in a tuple, thus `model3.layers[1].weight` is an array.
* Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
- Like `struct Poly3{T}` above, it has type parameters for its fields -- unlike our `Layer`, where the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
- - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+ - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
- It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
- And it has some performance tricks: making sure element types match, and re-using some memory.
* The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
47 changes: 47 additions & 0 deletions docs/src/guide/models/recurrence.md
@@ -166,3 +166,50 @@ opt_state = Flux.setup(AdamW(1e-3), model)
g = gradient(m -> Flux.mse(m(x), y), model)[1]
Flux.update!(opt_state, model, g)
```

Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:

```julia
rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3)
x = rand(Float32, 2, 4, 3)
y = rnn(x)
```
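
An initial state can also be passed explicitly as a second argument; if omitted, it is obtained from `Flux.initialstates`. A minimal sketch (illustrative, not part of this diff; for an `LSTMCell` the state is an `(h, c)` tuple):

```julia
# Explicit initial state for a Recurrence layer (zeros, same as the default).
state0 = Flux.initialstates(rnn.cell)
y = rnn(x, state0)  # equivalent to rnn(x)
```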

## Stacking recurrent layers

Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
For instance, a model with two LSTM layers can be defined as follows:

```julia
stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
x = rand(Float32, 3, 4)
y = stacked_rnn(x)
```
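
As a quick shape check (illustrative, not part of this diff), the input is `in × len` and the output is `out × len`, with the second LSTM's output dimension determining `out`:

```julia
# Input has 3 features over 4 time steps; output has 5 features over 4 steps.
@assert size(x) == (3, 4)
@assert size(y) == (5, 4)
```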

If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:

```julia
struct StackedRNN{L,S}
layers::L
states0::S
end

Flux.@layer StackedRNN

function StackedRNN(d::Int; num_layers::Int)
layers = [LSTM(d => d) for _ in 1:num_layers]
states0 = [Flux.initialstates(l) for l in layers]
return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
for (layer, state0) in zip(m.layers, m.states0)
x = layer(x, state0)
end
return x
end

rnn = StackedRNN(3; num_layers=2)
x = rand(Float32, 3, 10)
y = rnn(x)
```
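
Training such a stacked model follows the same pattern shown earlier in this guide; a minimal sketch, assuming a dummy target `y_target` and a mean-squared-error loss:

```julia
# Illustrative training step for StackedRNN, mirroring the earlier example.
y_target = rand(Float32, 3, 10)  # dummy target, assumed for illustration

opt_state = Flux.setup(AdamW(1e-3), rnn)
g = gradient(m -> Flux.mse(m(x), y_target), rnn)[1]
Flux.update!(opt_state, rnn, g)
```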
1 change: 1 addition & 0 deletions docs/src/reference/models/layers.md
@@ -104,6 +104,7 @@ PairwiseFusion
Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).

```@docs
Recurrence
RNNCell
RNN
LSTMCell
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
export Chain, Dense, Embedding, EmbeddingBag,
Maxout, SkipConnection, Parallel, PairwiseFusion,
RNNCell, LSTMCell, GRUCell, GRUv3Cell,
- RNN, LSTM, GRU, GRUv3,
+ RNN, LSTM, GRU, GRUv3, Recurrence,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout,
93 changes: 84 additions & 9 deletions src/layers/recurrent.jl
@@ -1,8 +1,7 @@
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

- function scan(cell, x, state0)
-     state = state0
+ function scan(cell, x, state)
y = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
@@ -12,7 +11,67 @@ function scan(cell, x, state0)
return stack(y, dims = 2)
end

"""
Recurrence(cell)
Create a recurrent layer that processes entire sequences out
of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
the initial hidden state. The output of the `cell` is considered to be:
1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
# Forward
rnn(x, [state])
The input `x` should be an array of size `in x len` or `in x len x batch_size`,
where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
`Flux.initialstates(cell)`.
The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.
The operation performed is semantically equivalent to the following code:
```julia
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]
state = Flux.initialstates(cell)
out = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
out = [out..., out_from_state(state)]
end
stack(out, dims = 2)
```
# Examples
```jldoctest
julia> rnn = Recurrence(RNNCell(2 => 3))
Recurrence(
RNNCell(2 => 3, tanh), # 18 parameters
) # Total: 3 arrays, 18 parameters, 232 bytes.
julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
julia> y = rnn(x); # out x len x batch_size
```
"""
struct Recurrence{M}
cell::M
end

@layer Recurrence

initialstates(rnn::Recurrence) = initialstates(rnn.cell)

(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)

# Vanilla RNN
@doc raw"""
@@ -185,9 +244,7 @@ julia> x = rand(Float32, (d_in, len, batch_size));
julia> h = zeros(Float32, (d_out, batch_size));
julia> rnn = RNN(d_in => d_out)
- RNN(
-   RNNCell(4 => 6, tanh),  # 66 parameters
- ) # Total: 3 arrays, 66 parameters, 424 bytes.
+ RNN(4 => 6, tanh)  # 66 parameters
julia> y = rnn(x, h); # [y] = [d_out, len, batch_size]
```
@@ -212,7 +269,7 @@ struct RNN{M}
cell::M
end

- @layer RNN
+ @layer :noexpand RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

@@ -233,6 +290,12 @@ function (m::RNN)(x::AbstractArray, h)
return scan(m.cell, x, h)
end

function Base.show(io::IO, m::RNN)
print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
print(io, ", ", m.cell.σ)
print(io, ")")
end


# LSTM
@doc raw"""
@@ -401,7 +464,7 @@ struct LSTM{M}
cell::M
end

- @layer LSTM
+ @layer :noexpand LSTM

initialstates(lstm::LSTM) = initialstates(lstm.cell)

@@ -417,6 +480,10 @@ function (m::LSTM)(x::AbstractArray, state0)
return scan(m.cell, x, state0)
end

function Base.show(io::IO, m::LSTM)
print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
end

# GRU

@doc raw"""
@@ -569,7 +636,7 @@ struct GRU{M}
cell::M
end

- @layer GRU
+ @layer :noexpand GRU

initialstates(gru::GRU) = initialstates(gru.cell)

@@ -585,6 +652,10 @@ function (m::GRU)(x::AbstractArray, h)
return scan(m.cell, x, h)
end

function Base.show(io::IO, m::GRU)
print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end

# GRU v3
@doc raw"""
GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -729,7 +800,7 @@ struct GRUv3{M}
cell::M
end

- @layer GRUv3
+ @layer :noexpand GRUv3

initialstates(gru::GRUv3) = initialstates(gru.cell)

@@ -744,3 +815,7 @@ function (m::GRUv3)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(m.cell, x, h)
end

function Base.show(io::IO, m::GRUv3)
print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end
9 changes: 9 additions & 0 deletions test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))
end

@testset "Recurrence" begin
x = rand(Float32, 2, 3, 4)
for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
cell = rnn.cell
rec = Recurrence(cell)
@test rec(x) ≈ rnn(x)
end
end
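
A natural extra check, sketched here for illustration only (not part of this commit): the default state should match an explicitly passed `Flux.initialstates`:

```julia
@testset "Recurrence explicit state" begin
    x = rand(Float32, 2, 3, 4)
    cell = RNNCell(2 => 3)
    rec = Recurrence(cell)
    # initialstates returns zeros, so both calls should agree
    @test rec(x) ≈ rec(x, Flux.initialstates(cell))
end
```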
