Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding initialstates function to RNNs #2541

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/reference/models/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ GRUCell
GRU
GRUv3Cell
GRUv3
initialstates
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
```

## Normalisation & Regularisation
Expand Down
2 changes: 1 addition & 1 deletion src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
# layers
Bilinear, Scale,
# utils
outputsize, state, create_bias, @layer,
outputsize, state, create_bias, @layer, initialstates,
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
# from Train
Expand Down
101 changes: 83 additions & 18 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,41 @@ end

@layer RNNCell

"""
initialstates(rnn) -> AbstractVector

Return the initial hidden state for the given cell or recurrent layer.
The returned vector is initialized to zeros and has the appropriate
dimension inferred from the cell's internal recurrent weight matrix.

# Arguments
- `rnn`: The recurrent neural network cell or recurrent layer for
which the initial state vector is requested. It can be any of
`RNNCell`, `RNN`, `LSTMCell`, `LSTM`, `GRUCell`, `GRU`,
`GRUv3Cell`, and `GRUv3`

# Returns
An `AbstractVector` of zeros representing the initial hidden state, whose length
matches the output dimension of the cell or recurrent layer.
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

# Example
```julia
using Flux

# Create an RNNCell from input dimension 10 to output dimension 20
rnn = RNNCell(10 => 20)

# Get the initial hidden state
h0 = initialstates(rnn)

# Get some input data
x = rand(Float32, 10)

# Run forward
res = rnn(x, h0)
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

function RNNCell(
(in, out)::Pair,
σ = tanh;
Expand All @@ -82,7 +117,10 @@ function RNNCell(
return RNNCell(σ, Wi, Wh, b)
end

(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
function (rnn::RNNCell)(x::AbstractVecOrMat)
state = initialstates(rnn)
return rnn(x, state)
end

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -173,12 +211,17 @@ end

@layer RNN

initialstates(rnn::RNN) = zeros_like(rnn.cell.Wh, size(rnn.cell.Wh, 1))
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
return RNN(cell)
end

(m::RNN)(x::AbstractArray) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
function (rnn::RNN)(x::AbstractArray)
state = initialstates(rnn)
return rnn(x, state)
end

function (m::RNN)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
Expand Down Expand Up @@ -261,6 +304,10 @@ end

@layer LSTMCell

function initialstates(lstm:: LSTMCell)
return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))
end

function LSTMCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -274,10 +321,9 @@ function LSTMCell(
return cell
end

function (m::LSTMCell)(x::AbstractVecOrMat)
h = zeros_like(x, size(m.Wh, 2))
c = zeros_like(h)
return m(x, (h, c))
function (lstm::LSTMCell)(x::AbstractVecOrMat)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
Expand Down Expand Up @@ -363,15 +409,20 @@ end

@layer LSTM

function initialstates(lstm::LSTM)
state = zeros_like(lstm.cell.Wh, size(lstm.cell.Wh, 2))
cstate = zeros_like(state)
return state, cstate
end

function LSTM((in, out)::Pair; cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
return LSTM(cell)
end

function (m::LSTM)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
c = zeros_like(h)
return m(x, (h, c))
function (lstm::LSTM)(x::AbstractArray)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end

function (m::LSTM)(x::AbstractArray, (h, c))
Expand Down Expand Up @@ -447,6 +498,8 @@ end

@layer GRUCell

initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -459,7 +512,10 @@ function GRUCell(
return GRUCell(Wi, Wh, b)
end

(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
function (gru::GRUCell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRUCell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -534,14 +590,16 @@ end

@layer GRU

initialstates(gru::GRU) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2))

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
return GRU(cell)
end

function (m::GRU)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
return m(x, h)
function (gru::GRU)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRU)(x::AbstractArray, h)
Expand Down Expand Up @@ -603,6 +661,8 @@ end

@layer GRUv3Cell

initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUv3Cell(
(in, out)::Pair;
init_kernel = glorot_uniform,
Expand All @@ -616,7 +676,10 @@ function GRUv3Cell(
return GRUv3Cell(Wi, Wh, b, Wh_h̃)
end

(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
function (gru::GRUv3Cell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -674,14 +737,16 @@ end

@layer GRUv3

initialstates(gru::GRUv3) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2))

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
return GRUv3(cell)
end

function (m::GRUv3)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
return m(x, h)
function (gru::GRUv3)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRUv3)(x::AbstractArray, h)
Expand Down
Loading