Skip to content

Commit

Permalink
Recurrence
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 12, 2024
1 parent f96bd58 commit 4ef5d11
Showing 1 changed file with 52 additions and 16 deletions.
68 changes: 52 additions & 16 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Extract the cell's observable output from the state it returned:
# a non-tuple state (e.g. `h` for RNN/GRU) is itself the output,
# while a tuple state (e.g. `(h, c)` for LSTM) exposes its first element.
out_from_state(state) = state
out_from_state(state::Tuple) = first(state)

function scan(cell, x, state0)
state = state0
function scan(cell, x, state)
y = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
Expand All @@ -12,7 +11,47 @@ function scan(cell, x, state0)
return stack(y, dims = 2)
end

"""
    Recurrent(cell)

Create a recurrent layer that processes entire sequences out
of a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.

The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
the initial hidden state. The output of the `cell` is considered to be:
1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.

# Forward

    rnn(x, [state])

The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`,
where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.

The operation performed is semantically equivalent to the following code:
```julia
state = Flux.initialstates(cell)
out = []
for x_t in eachslice(x, dims = 2)
    state = cell(x_t, state)
    out = [out..., out_from_state(state)]
end
stack(out, dims = 2)
```
"""
struct Recurrent{M}
    cell::M  # the wrapped step function, e.g. an RNNCell / LSTMCell / GRUCell
end

# Hook `Recurrent` into Flux's layer machinery (parameter collection, display).
@layer Recurrent

# The wrapper's initial state is simply its cell's initial state.
initialstates(rnn::Recurrent) = initialstates(rnn.cell)

# Without an explicit state, start from the cell's default initial state.
function (rnn::Recurrent)(x::AbstractArray)
    return rnn(x, initialstates(rnn))
end

# Unroll the cell across the time dimension (dim 2) of `x` via `scan`.
function (rnn::Recurrent)(x::AbstractArray, state)
    return scan(rnn.cell, x, state)
end

# Vanilla RNN
@doc raw"""
Expand Down Expand Up @@ -87,16 +126,15 @@ end
initialstates(rnn) -> AbstractVector
Return the initial hidden state for the given recurrent cell or recurrent layer.
Should be implemented for all recurrent cells and layers.
# Example
```julia
using Flux
# Create an RNNCell from input dimension 10 to output dimension 20
rnn = RNNCell(10 => 20)
# Get the initial hidden state
h0 = initialstates(rnn)
h0 = Flux.initialstates(rnn)
# Get some input data
x = rand(Float32, 10)
Expand All @@ -107,22 +145,20 @@ res = rnn(x, h0)
# A fresh RNN hidden state: a zero vector whose length equals the number of
# columns of the recurrent kernel `Wh` (the cell's output dimension).
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

"""
    RNNCell((in, out)::Pair, σ = tanh; init_kernel, init_recurrent_kernel, bias)

Construct a vanilla RNN cell mapping inputs of dimension `in` to hidden states
of dimension `out`, using activation `σ`.

# Arguments
- `(in, out)`: input and output (hidden) dimensions.
- `σ`: activation function applied to the pre-activation (default `tanh`).

# Keywords
- `init_kernel`: initializer for the input weight matrix `Wi` (default `glorot_uniform`).
- `init_recurrent_kernel`: initializer for the recurrent weight matrix `Wh` (default `glorot_uniform`).
- `bias`: whether to include a bias vector (default `true`).
"""
function RNNCell(
    (in, out)::Pair,
    σ = tanh;
    init_kernel = glorot_uniform,
    init_recurrent_kernel = glorot_uniform,
    bias = true,
)
    Wi = init_kernel(out, in)          # input-to-hidden weights, (out × in)
    Wh = init_recurrent_kernel(out, out)  # hidden-to-hidden weights, (out × out)
    b = create_bias(Wi, bias, size(Wi, 1))
    return RNNCell(σ, Wi, Wh, b)
end

# Apply the cell to `x` starting from the default (zero) initial hidden state.
(rnn::RNNCell)(x::AbstractVecOrMat) = rnn(x, initialstates(rnn))

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
Expand Down Expand Up @@ -300,7 +336,7 @@ end

@layer LSTMCell

# A fresh LSTM state: a `(h, c)` pair of zero vectors, each sized to the
# number of columns of the recurrent kernel `Wh` (the cell's output dimension).
function initialstates(lstm::LSTMCell)
    return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))
end

Expand Down

0 comments on commit 4ef5d11

Please sign in to comment.