Skip to content

Commit

Permalink
docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 12, 2024
1 parent 96e2b5e commit 5da71b8
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ In the forward pass, implements the function
```math
h^\prime = \sigma(W_i x + W_h h + b)
```
Returns a tuple `(out, state)`, where both element are given by `h'`.
See [`RNN`](@ref) for a layer that processes entire sequences.
Expand All @@ -43,7 +42,8 @@ The arguments of the forward pass are:
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(out, state)`, where both elements are given by the updated state `h'`.
Returns a tuple `(output, state)`, where both elements are given by the updated state `h'`,
a tensor of size `out` or `out x batch_size`.
# Examples
Expand Down Expand Up @@ -93,13 +93,13 @@ using Flux
rnn = RNNCell(10 => 20)
# Get the initial hidden state
h0 = initialstates(rnn)
state = initialstates(rnn)
# Get some input data
x = rand(Float32, 10)
# Run forward
res = rnn(x, h0)
out, state = rnn(x, state)
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

Expand Down Expand Up @@ -242,7 +242,6 @@ o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)
h_t = o_t \odot \tanh(c_t)
```
The `LSTMCell` returns the new hidden state `h_t` and cell state `c_t` for a single time step.
See also [`LSTM`](@ref) for a layer that processes entire sequences.
# Arguments
Expand All @@ -263,7 +262,8 @@ The arguments of the forward pass are:
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
Returns a tuple `(output, state)`, where `output = h'` is the new hidden state and `state = (h', c')` is the new hidden and cell states.
These are tensors of size `out` or `out x batch_size`.
# Examples
Expand All @@ -277,9 +277,9 @@ julia> c = zeros(Float32, 5); # cell state
julia> x = rand(Float32, 3, 4); # in x batch_size
julia> h′, c′ = l(x, (h, c));
julia> y, (h′, c′) = l(x, (h, c));
julia> size(h′) # out x batch_size
julia> size(y) # out x batch_size
(5, 4)
```
"""
Expand Down Expand Up @@ -445,7 +445,8 @@ The arguments of the forward pass are:
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
Returns the tuple `(output, state)`, where `output = h'` and `state = h'`.
The new hidden state `h'` is an array of size `out` or `out x batch_size`.
# Examples
Expand All @@ -457,7 +458,7 @@ julia> h = zeros(Float32, 5); # hidden state
julia> x = rand(Float32, 3, 4); # in x batch_size
julia> h′ = g(x, h);
julia> y, h = g(x, h);
```
"""
struct GRUCell{I, H, V}
Expand Down Expand Up @@ -612,7 +613,8 @@ The arguments of the forward pass are:
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
Returns the tuple `(output, state)`, where `output = h'` and `state = h'`.
The new hidden state `h'` is an array of size `out` or `out x batch_size`.
"""
struct GRUv3Cell{I, H, V, HH}
Wi::I
Expand Down

0 comments on commit 5da71b8

Please sign in to comment.