Skip to content

Commit

Permalink
fix doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 13, 2024
1 parent 89fc89f commit c4c582d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
31 changes: 9 additions & 22 deletions docs/src/guide/models/recurrence.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,34 +169,21 @@ Flux.update!(opt_state, model, g)

Finally, the [`Recurrence`](@ref) layer can be used wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows:

```jldoctest
julia> rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3)
Recurrence(
LSTMCell(2 => 3), # 72 parameters
) # Total: 3 arrays, 72 parameters, 448 bytes.
julia> y = rnn(rand(Float32, 2, 4, 3));
```julia
rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3)
x = rand(Float32, 2, 4, 3)
y = rnn(x)
```

## Stacking recurrent layers

Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
For instance, a model with two LSTM layers can be defined as follows:

```jldoctest
julia> stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
Chain(
LSTM(3 => 5), # 180 parameters
Dropout(0.5),
LSTM(5 => 5), # 220 parameters
) # Total: 6 arrays, 400 parameters, 1.898 KiB.
julia> x = rand(Float32, 3, 4);
julia> y = stacked_rnn(x);
julia> size(y)
(5, 4)
```julia
stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
x = rand(Float32, 3, 4)
y = stacked_rnn(x)
```

If more fine grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:
Expand All @@ -223,6 +210,6 @@ function (m::StackedRNN)(x)
end

rnn = StackedRNN(3; num_layers=2)
x = rand(Float32, 3, 2)
x = rand(Float32, 3, 10)
y = rnn(x)
```
3 changes: 3 additions & 0 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ stack(out, dims = 2)
```jldoctest
julia> rnn = Recurrence(RNNCell(2 => 3))
Recurrence(
RNNCell(2 => 3, tanh), # 18 parameters
) # Total: 3 arrays, 18 parameters, 232 bytes.
julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size
Expand Down

0 comments on commit c4c582d

Please sign in to comment.