From c4c582d1a68fb26cc23b4c736ba5ffb0b3faefea Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 13 Dec 2024 17:40:19 +0100 Subject: [PATCH] fix doctests --- docs/src/guide/models/recurrence.md | 31 +++++++++-------------------- src/layers/recurrent.jl | 3 +++ 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md index 3079ab7311..446ff86f82 100644 --- a/docs/src/guide/models/recurrence.md +++ b/docs/src/guide/models/recurrence.md @@ -169,13 +169,10 @@ Flux.update!(opt_state, model, g) Finally, the [`Recurrence`](@ref) layer can be used wrap any recurrent cell to process the entire sequence at once. For instance, a type behaving the same as the `LSTM` layer can be defined as follows: -```jldoctest -julia> rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3) -Recurrence( - LSTMCell(2 => 3), # 72 parameters -) # Total: 3 arrays, 72 parameters, 448 bytes. - -julia> y = rnn(rand(Float32, 2, 4, 3)); +```julia +rnn = Recurrence(LSTMCell(2 => 3)) # similar to LSTM(2 => 3) +x = rand(Float32, 2, 4, 3) +y = rnn(x) ``` ## Stacking recurrent layers @@ -183,20 +180,10 @@ julia> y = rnn(rand(Float32, 2, 4, 3)); Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain. For instance, a model with two LSTM layers can be defined as follows: -```jldoctest -julia> stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5)) -Chain( - LSTM(3 => 5), # 180 parameters - Dropout(0.5), - LSTM(5 => 5), # 220 parameters -) # Total: 6 arrays, 400 parameters, 1.898 KiB. - -julia> x = rand(Float32, 3, 4); - -julia> y = stacked_rnn(x); - -julia> size(y) -(5, 4) +```julia +stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5)) +x = rand(Float32, 3, 4) +y = stacked_rnn(x) ``` If more fine grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows: @@ -223,6 +210,6 @@ function (m::StackedRNN)(x) end rnn = StackedRNN(3; num_layers=2) -x = rand(Float32, 3, 2) +x = rand(Float32, 3, 10) y = rnn(x) ``` diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 5b88768ad1..a6a82b0e72 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -53,6 +53,9 @@ stack(out, dims = 2) ```jldoctest julia> rnn = Recurrence(RNNCell(2 => 3)) +Recurrence( + RNNCell(2 => 3, tanh), # 18 parameters +) # Total: 3 arrays, 18 parameters, 232 bytes. julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size