Commit: recurrence

CarloLucibello committed Dec 13, 2024
1 parent b8ed5d4 commit 098c064
Showing 4 changed files with 117 additions and 13 deletions.
60 changes: 60 additions & 0 deletions docs/src/guide/models/recurrence.md
@@ -166,3 +166,63 @@ opt_state = Flux.setup(AdamW(1e-3), model)
g = gradient(m -> Flux.mse(m(x), y), model)[1]
Flux.update!(opt_state, model, g)
```

Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell and process the entire sequence at once. For instance, a layer behaving the same as the built-in `LSTM` layer can be defined as follows:

```jldoctest
julia> rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3)
Recurrence(
LSTMCell(2 => 3), # 72 parameters
) # Total: 3 arrays, 72 parameters, 448 bytes.

julia> y = rnn(rand(Float32, 2, 4, 3));
```
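
The initial state can also be passed explicitly as a second argument, as in `rnn(x, state)`. A minimal sketch, assuming the state layout returned by `Flux.initialstates` (an `(h, c)` tuple of zeros for an `LSTMCell`):

```julia
state = Flux.initialstates(rnn.cell)  # (h, c) tuple of zero arrays for an LSTMCell
y = rnn(rand(Float32, 2, 4, 3), state)
```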

## Stacking recurrent layers

Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of each layer is fed as input to the next layer in the chain.
For instance, a model with two LSTM layers can be defined as follows:

```jldoctest
julia> stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
Chain(
LSTM(3 => 5), # 180 parameters
Dropout(0.5),
LSTM(5 => 5), # 220 parameters
) # Total: 6 arrays, 400 parameters, 1.898 KiB.

julia> x = rand(Float32, 3, 4);

julia> y = stacked_rnn(x);

julia> size(y)
(5, 4)
```

If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:

```julia
struct StackedRNN{L,S}
layers::L
states0::S
end

Flux.@layer StackedRNN

function StackedRNN(d::Int; num_layers::Int)
layers = [LSTM(d => d) for _ in 1:num_layers]
states0 = [Flux.initialstates(l) for l in layers]
return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end

rnn = StackedRNN(3; num_layers=2)
x = rand(Float32, 3, 2)
y = rnn(x)
```
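
Because `states0` is a field of the `Flux.@layer` struct, the initial states are trainable parameters and receive gradients together with the layer weights. A minimal training step in the style of the earlier example; the `target` array is a hypothetical stand-in for real labels:

```julia
target = rand(Float32, 3, 2)  # hypothetical target, same size as the model output

opt_state = Flux.setup(AdamW(1e-3), rnn)
g = gradient(m -> Flux.mse(m(x), target), rnn)[1]
Flux.update!(opt_state, rnn, g)
```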
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
export Chain, Dense, Embedding, EmbeddingBag,
Maxout, SkipConnection, Parallel, PairwiseFusion,
RNNCell, LSTMCell, GRUCell, GRUv3Cell,
-  RNN, LSTM, GRU, GRUv3,
+  RNN, LSTM, GRU, GRUv3, Recurrence,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout,
59 changes: 47 additions & 12 deletions src/layers/recurrent.jl
@@ -28,30 +28,47 @@ the initial hidden state. The output of the `cell` is considered to be:
rnn(x, [state])
-The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`,
-where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.
+The input `x` should be an array of size `in x len` or `in x len x batch_size`,
+where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
`Flux.initialstates(cell)`.
The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.
The operation performed is semantically equivalent to the following code:
```julia
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]
state = Flux.initialstates(cell)
out = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
-    out = [out..., get_output(state)]
+    out = [out..., out_from_state(state)]
end
stack(out, dims = 2)
```
# Examples
```jldoctest
julia> rnn = Recurrence(RNNCell(2 => 3));

julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size

julia> y = rnn(x); # out x len x batch_size
```
"""
-struct Recurrent{M}
+struct Recurrence{M}
cell::M
end

-@layer Recurrent
+@layer Recurrence

-initialstates(rnn::Recurrent) = initialstates(rnn.cell)
+initialstates(rnn::Recurrence) = initialstates(rnn.cell)

-(rnn::Recurrent)(x::AbstractArray) = rnn(x, initialstates(rnn))
-(rnn::Recurrent)(x::AbstractArray, state) = scan(rnn.cell, x, state)
+(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)

# Vanilla RNN
@doc raw"""
@@ -250,7 +267,7 @@ struct RNN{M}
cell::M
end

-@layer RNN
+@layer :noexpand RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

@@ -271,6 +288,12 @@ function (m::RNN)(x::AbstractArray, h)
return scan(m.cell, x, h)
end

function Base.show(io::IO, m::RNN)
print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
print(io, ", ", m.cell.σ)
print(io, ")")
end


# LSTM
@doc raw"""
@@ -439,7 +462,7 @@ struct LSTM{M}
cell::M
end

-@layer LSTM
+@layer :noexpand LSTM

initialstates(lstm::LSTM) = initialstates(lstm.cell)

@@ -455,6 +478,10 @@ function (m::LSTM)(x::AbstractArray, state0)
return scan(m.cell, x, state0)
end

function Base.show(io::IO, m::LSTM)
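    # Wi stacks the four gate weight matrices vertically, so out = size(Wi, 1) ÷ 4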
print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
end

# GRU

@doc raw"""
@@ -607,7 +634,7 @@ struct GRU{M}
cell::M
end

-@layer GRU
+@layer :noexpand GRU

initialstates(gru::GRU) = initialstates(gru.cell)

@@ -623,6 +650,10 @@ function (m::GRU)(x::AbstractArray, h)
return scan(m.cell, x, h)
end

function Base.show(io::IO, m::GRU)
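    # Wi stacks the three gate weight matrices vertically, so out = size(Wi, 1) ÷ 3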
print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end

# GRU v3
@doc raw"""
GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -767,7 +798,7 @@ struct GRUv3{M}
cell::M
end

-@layer GRUv3
+@layer :noexpand GRUv3

initialstates(gru::GRUv3) = initialstates(gru.cell)

@@ -782,3 +813,7 @@ function (m::GRUv3)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
return scan(m.cell, x, h)
end

function Base.show(io::IO, m::GRUv3)
print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end
9 changes: 9 additions & 0 deletions test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))
end

@testset "Recurrence" begin
for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
cell = rnn.cell
rec = Recurrence(cell)
x = rand(Float32, 2, 3, 4)
@test rec(x) ≈ rnn(x)
end
end
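
# A possible extra check (a sketch, not part of this commit): an explicit initial
# state from initialstates should match the default no-state call.
@testset "Recurrence with explicit state" begin
    cell = RNNCell(2 => 3)
    rec = Recurrence(cell)
    x = rand(Float32, 2, 3, 4)
    state = Flux.initialstates(cell)
    @test rec(x, state) ≈ rec(x)
end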
