Skip to content

Commit

Permalink
added initialstates to recurrent layers, added docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 7, 2024
1 parent 9e1b1bb commit a9bc95f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/src/reference/models/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ GRUCell
GRU
GRUv3Cell
GRUv3
initialstates
```

## Normalisation & Regularisation
Expand Down
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
# layers
Bilinear, Scale,
# utils
outputsize, state, create_bias, @layer,
outputsize, state, create_bias, @layer, initialstates,
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
# from Train
Expand Down
71 changes: 59 additions & 12 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,39 @@ end

@layer RNNCell

"""
initialstates(rnn) -> AbstractVector
Return the initial hidden state for the given cell or recurrent layer.
The returned vector is initialized to zeros and has the appropriate
dimension inferred from the cell's internal recurrent weight matrix.
# Arguments
- `rnn`: The recurrent neural network cell or recurrent layer for
which the initial state vector is requested. It can be any of
`RNNCell`, `RNN`, `LSTMCell`, `LSTM`, `GRUCell`, `GRU`,
`GRUv3Cell`, and `GRUv3`
# Returns
An `AbstractVector` of zeros representing the initial hidden state, whose length
matches the output dimension of the cell or recurrent layer.
# Example
```julia
using Flux
# Create an RNNCell from input dimension 10 to output dimension 20
rnn = RNNCell(10 => 20)
# Get the initial hidden state
h0 = initialstates(rnn)
# Get some input data
x = rand(Float32, 10)
# Run forward
res = rnn(x, h0)
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

function RNNCell(
Expand All @@ -86,7 +119,7 @@ end

function (rnn::RNNCell)(x::AbstractVecOrMat)
state = initialstates(rnn)
rnn(x, state)
return rnn(x, state)
end

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
Expand Down Expand Up @@ -178,12 +211,17 @@ end

@layer RNN

initialstates(rnn::RNN) = zeros_like(x, size(rnn.cell.Wh, 1))

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
return RNN(cell)
end

(m::RNN)(x::AbstractArray) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
function (m::RNN)(x::AbstractArray)
state = initialstates(rnn)
return rnn(x, state)
end

function (m::RNN)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
Expand Down Expand Up @@ -371,15 +409,20 @@ end

@layer LSTM

function initialstates(lstm::LSTM)
state = zeros_like(x, size(lstm.cell.Wh, 2))
cstate = zeros_like(state)
return state, cstate
end

function LSTM((in, out)::Pair; cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
return LSTM(cell)
end

function (m::LSTM)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
c = zeros_like(h)
return m(x, (h, c))
function (lstm::LSTM)(x::AbstractArray)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end

function (m::LSTM)(x::AbstractArray, (h, c))
Expand Down Expand Up @@ -547,14 +590,16 @@ end

@layer GRU

initialstates(gru::GRU) = zeros_like(x, size(gru.cell.Wh, 2))

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
return GRU(cell)
end

function (m::GRU)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
return m(x, h)
function (gru::GRU)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRU)(x::AbstractArray, h)
Expand Down Expand Up @@ -692,14 +737,16 @@ end

@layer GRUv3

initialstates(gru::GRUv3) = zeros_like(x, size(gru.cell.Wh, 2))

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
return GRUv3(cell)
end

function (m::GRUv3)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
return m(x, h)
function (gru::GRUv3)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRUv3)(x::AbstractArray, h)
Expand Down

0 comments on commit a9bc95f

Please sign in to comment.