diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index dbef6f9932..2a80ee2939 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -21,7 +21,6 @@ In the forward pass, implements the function
 ```math
 h^\prime = \sigma(W_i x + W_h h + b)
 ```
-Returns a tuple `(out, state)`, where both element are given by `h'`.
 
 See [`RNN`](@ref) for a layer that processes entire sequences.
 
@@ -43,7 +42,8 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the RNN. It should be a vector of size `out`
   or a matrix of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros,
   initialized by [`initialstates`](@ref).
-Returns a tuple `(out, state)`, where both elements are given by the updated state `h'`.
+Returns a tuple `(output, state)`, where both elements are given by the updated state `h'`,
+a tensor of size `out` or `out x batch_size`.
 
 # Examples
 
@@ -93,13 +93,13 @@ using Flux
 rnn = RNNCell(10 => 20)
 
 # Get the initial hidden state
-h0 = initialstates(rnn)
+state = initialstates(rnn)
 
 # Get some input data
 x = rand(Float32, 10)
 
 # Run forward
-res = rnn(x, h0)
+out, state = rnn(x, state)
 ```
 """
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
@@ -242,7 +242,6 @@ o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
 h_t = o_t \odot \tanh(c_t)
 ```
-The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step.
 
 See also [`LSTM`](@ref) for a layer that processes entire sequences.
 
@@ -263,7 +262,8 @@ The arguments of the forward pass are:
   They should be vectors of size `out` or matrices of size `out x batch_size`.
   If not provided, they are assumed to be vectors of zeros,
   initialized by [`initialstates`](@ref).
-Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
+Returns a tuple `(output, state)`, where `output = h'` is the new hidden state and `state = (h', c')` contains the new hidden and cell states.
+These are tensors of size `out` or `out x batch_size`.
 
 # Examples
 
@@ -277,9 +277,9 @@ julia> c = zeros(Float32, 5); # cell state
 
 julia> x = rand(Float32, 3, 4); # in x batch_size
 
-julia> h′, c′ = l(x, (h, c));
+julia> y, (h′, c′) = l(x, (h, c));
 
-julia> size(h′) # out x batch_size
+julia> size(y) # out x batch_size
 (5, 4)
 ```
 """
@@ -445,7 +445,8 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the GRU. It should be a vector of size `out`
   or a matrix of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros,
   initialized by [`initialstates`](@ref).
-Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
+Returns a tuple `(output, state)`, where `output = h'` and `state = h'`.
+The new hidden state `h'` is an array of size `out` or `out x batch_size`.
 
 # Examples
 
@@ -457,7 +458,7 @@ julia> h = zeros(Float32, 5); # hidden state
 
 julia> x = rand(Float32, 3, 4); # in x batch_size
 
-julia> h′ = g(x, h);
+julia> y, h = g(x, h);
 ```
 """
 struct GRUCell{I, H, V}
@@ -612,7 +613,8 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the GRU. It should be a vector of size `out`
   or a matrix of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros,
   initialized by [`initialstates`](@ref).
-Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
+Returns a tuple `(output, state)`, where `output = h'` and `state = h'`.
+The new hidden state `h'` is an array of size `out` or `out x batch_size`.
 """
 struct GRUv3Cell{I, H, V, HH}
     Wi::I
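Taken together, these docstring changes describe one unified convention: every cell's forward pass returns an `(output, state)` pair. A minimal sketch of that calling convention across the three cells (layer sizes here are illustrative, and `initialstates` is assumed to be reachable as `Flux.initialstates`):

```julia
using Flux

x = rand(Float32, 3, 4)                   # in x batch_size

# RNNCell and GRUCell: output and state are both the updated h'.
rnn = RNNCell(3 => 5)
y, h = rnn(x, Flux.initialstates(rnn))    # y and h of size out x batch_size

gru = GRUCell(3 => 5)
y, h = gru(x, Flux.initialstates(gru))

# LSTMCell: output is h', state is the tuple (h', c').
lstm = LSTMCell(3 => 5)
y, (h, c) = lstm(x, Flux.initialstates(lstm))
```

Under this convention, downstream code can treat any cell uniformly by keeping `state` opaque and only consuming `output`.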