diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9386a3fc2d..3affdc2973 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,8 +1,7 @@
 out_from_state(state) = state
 out_from_state(state::Tuple) = state[1]
 
-function scan(cell, x, state0)
-    state = state0
+function scan(cell, x, state)
     y = []
     for x_t in eachslice(x, dims = 2)
         state = cell(x_t, state)
@@ -12,7 +11,47 @@ function scan(cell, x, state0)
     return stack(y, dims = 2)
 end
 
+"""
+    Recurrent(cell)
+
+Create a recurrent layer that processes an entire sequence using
+a recurrent `cell`, such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
+in the same way that [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
+
+The `cell` should be a callable object that takes an input `x` and a hidden state `state`,
+and returns the updated hidden state `state'`. The `cell` should also implement the
+`initialstates` method returning the initial hidden state. The output of the `cell` is taken to be:
+1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
+2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
+
+# Forward
+
+    rnn(x, [state])
+
+The input `x` should be a matrix of size `in x len` or an array of size `in x len x batch_size`,
+where `in` is the input dimension, `len` is the sequence length, and `batch_size` is the batch size.
+
+The operation performed is semantically equivalent to the following code:
+```julia
+state = Flux.initialstates(cell)
+out = []
+for x_t in eachslice(x, dims = 2)
+    state = cell(x_t, state)
+    out = [out..., out_from_state(state)]
+end
+stack(out, dims = 2)
+```
+"""
+struct Recurrent{M}
+    cell::M
+end
+
+@layer Recurrent
+initialstates(rnn::Recurrent) = initialstates(rnn.cell)
+
+(rnn::Recurrent)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrent)(x::AbstractArray, state) = scan(rnn.cell, x, state)
 
 
 # Vanilla RNN
 @doc raw"""
@@ -87,16 +126,15 @@ end
     initialstates(rnn) -> AbstractVector
 
 Return the initial hidden state for the given recurrent cell or recurrent layer.
+This method should be implemented for all recurrent cells and layers.
 
 # Example
 ```julia
-using Flux
-
 # Create an RNNCell from input dimension 10 to output dimension 20
 rnn = RNNCell(10 => 20)
 
 # Get the initial hidden state
-h0 = initialstates(rnn)
+h0 = Flux.initialstates(rnn)
 
 # Get some input data
 x = rand(Float32, 10)
@@ -107,22 +145,20 @@ res = rnn(x, h0)
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
 
 function RNNCell(
-    (in, out)::Pair,
-    σ = tanh;
-    init_kernel = glorot_uniform,
-    init_recurrent_kernel = glorot_uniform,
-    bias = true,
-)
+        (in, out)::Pair,
+        σ = tanh;
+        init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform,
+        bias = true,
+    )
+
     Wi = init_kernel(out, in)
     Wh = init_recurrent_kernel(out, out)
     b = create_bias(Wi, bias, size(Wi, 1))
     return RNNCell(σ, Wi, Wh, b)
 end
 
-function (rnn::RNNCell)(x::AbstractVecOrMat)
-    state = initialstates(rnn)
-    return rnn(x, state)
-end
+(rnn::RNNCell)(x::AbstractVecOrMat) = rnn(x, initialstates(rnn))
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi, 2))
@@ -300,7 +336,7 @@ end
 
 @layer LSTMCell
 
-function initialstates(lstm:: LSTMCell)
+function initialstates(lstm::LSTMCell)
     return zeros_like(lstm.Wh, size(lstm.Wh, 2)),
         zeros_like(lstm.Wh, size(lstm.Wh, 2))
 end
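
Reviewer note: below is a quick smoke test for the new `Recurrent` wrapper, written only from the docstring and methods added in this diff. It is a minimal sketch, assuming the layer is not exported (hence the `Flux.Recurrent` qualification) and that `RNNCell` behaves as in the unchanged parts of this file.

```julia
using Flux

# Wrap a cell into a layer that runs over the whole sequence.
rnn = Flux.Recurrent(RNNCell(2 => 3))

# Input of size in x len x batch_size, per the new docstring.
x = rand(Float32, 2, 5, 4)

# With no state argument, the initial state comes from Flux.initialstates.
y = rnn(x)
@assert size(y) == (3, 5, 4)  # out x len x batch_size

# Passing the initial state explicitly gives the same result.
h0 = Flux.initialstates(rnn)
@assert rnn(x, h0) == y
```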
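A similar sketch for the tuple-state case handled by `out_from_state` and by the `initialstates(lstm::LSTMCell)` method fixed in the last hunk; it assumes the cell API of the surrounding file, where an `LSTMCell` takes and returns an `(h, c)` tuple.

```julia
using Flux

lstm = LSTMCell(2 => 3)

# For LSTM the initial state is a tuple (h, c), as returned by the fixed method.
state = Flux.initialstates(lstm)
@assert state isa Tuple && length(state) == 2

# A single step returns the updated (h, c).
h, c = lstm(rand(Float32, 2), state)
@assert size(h) == (3,) && size(c) == (3,)

# Wrapped in Recurrent, only the h part of each step is collected,
# since out_from_state(state::Tuple) = state[1].
seq = rand(Float32, 2, 5)  # in x len matrix input
@assert size(Flux.Recurrent(lstm)(seq)) == (3, 5)
```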