Recurrence layer #2549

Merged 6 commits on Dec 13, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -75,4 +75,4 @@ jobs:
env:
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
  DATADEPS_ALWAYS_ACCEPT: true

1 change: 1 addition & 0 deletions docs/make.jl
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers,
    OneHotArrays, Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
    DataFrames, JLD2, MLDataDevices

ENV["DATADEPS_ALWAYS_ACCEPT"] = true

DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)

4 changes: 2 additions & 2 deletions docs/src/guide/models/basics.md
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.
<h3><img src="../../../assets/zygote-crop.png" width="40px"/>&nbsp;<a href="https://github.com/FluxML/Zygote.jl">Zygote.jl</a></h3>
```

Flux's [`gradient`](@ref) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)`
hooks into Julia's compiler to find out what operations `f` contains, and transforms this
to produce code for computing `∂f/∂x`.
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
Its contents are stored in a tuple, thus `model3.layers[1].weight` is an array.
* Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
- Like `struct Poly3{T}` above, it has type parameters for its fields -- the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
- Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
- Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
- It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
- And it has some performance tricks: making sure element types match, and re-using some memory.
* The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
47 changes: 47 additions & 0 deletions docs/src/guide/models/recurrence.md
@@ -166,3 +166,50 @@ opt_state = Flux.setup(AdamW(1e-3), model)
g = gradient(m -> Flux.mse(m(x), y), model)[1]
Flux.update!(opt_state, model, g)
```

Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:

```julia
rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3)
x = rand(Float32, 2, 4, 3)
y = rnn(x)
```
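
This wrapping is exactly what the PR's new tests exercise: reusing the cell from an existing layer inside `Recurrence` reproduces that layer's output. A minimal sketch:

```julia
lstm = LSTM(2 => 3)
rec = Recurrence(lstm.cell)  # wrap the very same cell the LSTM layer uses

x = rand(Float32, 2, 4, 3)   # in x len x batch_size
rec(x) ≈ lstm(x)             # true: both scan the cell over the sequence
```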

## Stacking recurrent layers

Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
For instance, a model with two LSTM layers can be defined as follows:

```julia
stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
x = rand(Float32, 3, 4)
y = stacked_rnn(x)
```

If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:

```julia
struct StackedRNN{L,S}
    layers::L
    states0::S
end

Flux.@layer StackedRNN

function StackedRNN(d::Int; num_layers::Int)
    layers = [LSTM(d => d) for _ in 1:num_layers]
    states0 = [Flux.initialstates(l) for l in layers]
    return StackedRNN(layers, states0)
end

function (m::StackedRNN)(x)
    for (layer, state0) in zip(m.layers, m.states0)
        x = layer(x, state0)
    end
    return x
end

rnn = StackedRNN(3; num_layers=2)
x = rand(Float32, 3, 10)
y = rnn(x)
```
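
Here `Flux.@layer` registers every field of `StackedRNN`, so the arrays in `states0` are collected as trainable parameters alongside the LSTM weights, which is what makes the initial hidden states trainable.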
1 change: 1 addition & 0 deletions docs/src/reference/models/layers.md
@@ -104,6 +104,7 @@ PairwiseFusion
Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).

```@docs
Recurrence
RNNCell
RNN
LSTMCell
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
export Chain, Dense, Embedding, EmbeddingBag,
  Maxout, SkipConnection, Parallel, PairwiseFusion,
  RNNCell, LSTMCell, GRUCell, GRUv3Cell,
  RNN, LSTM, GRU, GRUv3,
  RNN, LSTM, GRU, GRUv3, Recurrence,
  SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
  AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
  Dropout, AlphaDropout,
93 changes: 84 additions & 9 deletions src/layers/recurrent.jl
@@ -1,8 +1,7 @@
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

function scan(cell, x, state0)
    state = state0
function scan(cell, x, state)
    y = []
    for x_t in eachslice(x, dims = 2)
        state = cell(x_t, state)
@@ -12,7 +11,67 @@
    return stack(y, dims = 2)
end
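
For reference, a sketch of the complete helper after this change, reconstructed from the hunk above plus the "semantically equivalent" snippet in the docstring below (the elided accumulation line is assumed to match that snippet):

```julia
function scan(cell, x, state)
    y = []
    for x_t in eachslice(x, dims = 2)      # step along the `len` dimension
        state = cell(x_t, state)           # the cell returns its new state
        y = [y..., out_from_state(state)]  # collect the per-step output
    end
    return stack(y, dims = 2)              # out x len (x batch_size)
end
```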

"""
    Recurrence(cell)

Create a recurrent layer that processes entire sequences out
of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.

The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
the initial hidden state. The output of the `cell` is considered to be:
1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.

# Forward

    rnn(x, [state])

The input `x` should be an array of size `in x len` or `in x len x batch_size`,
where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
`Flux.initialstates(cell)`.

The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.

The operation performed is semantically equivalent to the following code:
```julia
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

state = Flux.initialstates(cell)
out = []
for x_t in eachslice(x, dims = 2)
    state = cell(x_t, state)
    out = [out..., out_from_state(state)]
end
stack(out, dims = 2)
```

# Examples

```jldoctest
julia> rnn = Recurrence(RNNCell(2 => 3))
Recurrence(
  RNNCell(2 => 3, tanh),                # 18 parameters
)                   # Total: 3 arrays, 18 parameters, 232 bytes.

julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size

julia> y = rnn(x); # out x len x batch_size
```
"""
struct Recurrence{M}
    cell::M
end

@layer Recurrence

initialstates(rnn::Recurrence) = initialstates(rnn.cell)

(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)

# Vanilla RNN
@doc raw"""
@@ -184,9 +243,7 @@
julia> h = zeros(Float32, (d_out, batch_size));

julia> rnn = RNN(d_in => d_out)
RNN(
  RNNCell(4 => 6, tanh),                # 66 parameters
)                   # Total: 3 arrays, 66 parameters, 424 bytes.
RNN(4 => 6, tanh) # 66 parameters

julia> y = rnn(x, h); # [y] = [d_out, len, batch_size]
```
@@ -211,7 +268,7 @@
    cell::M
end

@layer RNN
@layer :noexpand RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

@@ -232,6 +289,12 @@
    return scan(m.cell, x, h)
end

function Base.show(io::IO, m::RNN)
    # Compact one-line display: input size => output size, then the activation.
    print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
    print(io, ", ", m.cell.σ)
    print(io, ")")
end
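
Together with `@layer :noexpand`, this gives the compact one-line display shown in the updated docstring above, e.g.:

```julia
julia> RNN(4 => 6)
RNN(4 => 6, tanh)  # 66 parameters
```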


# LSTM
@doc raw"""
@@ -400,7 +463,7 @@
    cell::M
end

@layer LSTM
@layer :noexpand LSTM

initialstates(lstm::LSTM) = initialstates(lstm.cell)

@@ -416,6 +479,10 @@
    return scan(m.cell, x, state0)
end

function Base.show(io::IO, m::LSTM)
    # Wi stacks the four gate matrices along dim 1, so `out` is size(Wi, 1) ÷ 4.
    print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
end

# GRU

@doc raw"""
@@ -568,7 +635,7 @@
    cell::M
end

@layer GRU
@layer :noexpand GRU

initialstates(gru::GRU) = initialstates(gru.cell)

@@ -584,6 +651,10 @@
    return scan(m.cell, x, h)
end

function Base.show(io::IO, m::GRU)
    # Wi stacks the three gate matrices along dim 1, so `out` is size(Wi, 1) ÷ 3.
    print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end

# GRU v3
@doc raw"""
    GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -728,7 +799,7 @@
    cell::M
end

@layer GRUv3
@layer :noexpand GRUv3

initialstates(gru::GRUv3) = initialstates(gru.cell)

@@ -743,3 +814,7 @@
    @assert ndims(x) == 2 || ndims(x) == 3
    return scan(m.cell, x, h)
end

function Base.show(io::IO, m::GRUv3)
    # Wi stacks the three gate matrices along dim 1, so `out` is size(Wi, 1) ÷ 3.
    print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
end
9 changes: 9 additions & 0 deletions test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
    # no initial state same as zero initial state
    @test gru(x) ≈ gru(x, zeros(Float32, 4))
end

@testset "Recurrence" begin
    x = rand(Float32, 2, 3, 4)
    for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
        cell = rnn.cell
        rec = Recurrence(cell)
        @test rec(x) ≈ rnn(x)
    end
end
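
A natural follow-up check, hypothetical and not part of this PR, would exercise the explicit-state form of the forward pass as well:

```julia
@testset "Recurrence with explicit state" begin  # hypothetical extra test
    cell = LSTMCell(2 => 3)
    rec = Recurrence(cell)
    x = rand(Float32, 2, 3, 4)
    # passing the initial state explicitly should match the default call
    @test rec(x, Flux.initialstates(cell)) ≈ rec(x)
end
```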