diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index acd279d709..dfa4a442f2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -75,4 +75,4 @@ jobs:
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
           DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
-          DATADEPS_ALWAYS_ACCEPT: true
+
diff --git a/docs/make.jl b/docs/make.jl
index 4367639d8e..d74486936b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -2,6 +2,7 @@ using Documenter, Flux, NNlib, Functors, MLUtils, BSON, Optimisers, OneHotArrays,
       Zygote, ChainRulesCore, Plots, MLDatasets, Statistics,
       DataFrames, JLD2, MLDataDevices
 
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 
 DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive = true)
diff --git a/docs/src/guide/models/basics.md b/docs/src/guide/models/basics.md
index 7ad62ee207..7141fb0b28 100644
--- a/docs/src/guide/models/basics.md
+++ b/docs/src/guide/models/basics.md
@@ -185,7 +185,7 @@ These matching nested structures are at the core of how Flux works.

 Zygote.jl

 ```
-Flux's [`gradient`](@ref) function by default calls a companion packages called [Zygote](https://github.com/FluxML/Zygote.jl).
+Flux's [`gradient`](@ref Flux.gradient) function by default calls a companion package called [Zygote](https://github.com/FluxML/Zygote.jl).
 Zygote performs source-to-source automatic differentiation, meaning that `gradient(f, x)` hooks into Julia's compiler to find
 out what operations `f` contains, and transforms this to produce code for computing `∂f/∂x`.
@@ -372,7 +372,7 @@ How does this `model3` differ from the `model1` we had before?
   Its contents is stored in a tuple, thus `model3.layers[1].weight` is an array.
 * Flux's layer [`Dense`](@ref Flux.Dense) has only minor differences from our `struct Layer`:
   - Like `struct Poly3{T}` above, it has type parameters for its fields -- the compiler does not know exactly what type `layer3s.W` will be, which costs speed.
-  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref) by default.
+  - Its initialisation uses not `randn` (normal distribution) but [`glorot_uniform`](@ref Flux.glorot_uniform) by default.
   - It reshapes some inputs (to allow several batch dimensions), and produces more friendly errors on wrong-size input.
   - And it has some performance tricks: making sure element types match, and re-using some memory.
 * The function [`σ`](@ref NNlib.sigmoid) is calculated in a slightly better way,
diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md
index 5b2e70f095..446ff86f82 100644
--- a/docs/src/guide/models/recurrence.md
+++ b/docs/src/guide/models/recurrence.md
@@ -166,3 +166,50 @@ opt_state = Flux.setup(AdamW(1e-3), model)
 g = gradient(m -> Flux.mse(m(x), y), model)[1]
 Flux.update!(opt_state, model, g)
 ```
+
+Finally, the [`Recurrence`](@ref) layer can be used to wrap any recurrent cell so that it processes the entire sequence at once. For instance, a layer behaving the same as `LSTM` can be constructed as follows:
+
+```julia
+rnn = Recurrence(LSTMCell(2 => 3))  # similar to LSTM(2 => 3)
+x = rand(Float32, 2, 4, 3)
+y = rnn(x)
+```
+
+## Stacking recurrent layers
+
+Recurrent layers can be stacked to form a deeper model by simply chaining them together using the [`Chain`](@ref) layer. The output of a layer is fed as input to the next layer in the chain.
+For instance, a model with two LSTM layers can be defined as follows:
+
+```julia
+stacked_rnn = Chain(LSTM(3 => 5), Dropout(0.5), LSTM(5 => 5))
+x = rand(Float32, 3, 4)
+y = stacked_rnn(x)
+```
+
+If more fine-grained control is needed, for instance to have a trainable initial hidden state, one can define a custom model as follows:
+
+```julia
+struct StackedRNN{L,S}
+    layers::L
+    states0::S
+end
+
+Flux.@layer StackedRNN
+
+function StackedRNN(d::Int; num_layers::Int)
+    layers = [LSTM(d => d) for _ in 1:num_layers]
+    states0 = [Flux.initialstates(l) for l in layers]
+    return StackedRNN(layers, states0)
+end
+
+function (m::StackedRNN)(x)
+    for (layer, state0) in zip(m.layers, m.states0)
+        x = layer(x, state0)
+    end
+    return x
+end
+
+rnn = StackedRNN(3; num_layers=2)
+x = rand(Float32, 3, 10)
+y = rnn(x)
+```
diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md
index 355d3e7833..562304de70 100644
--- a/docs/src/reference/models/layers.md
+++ b/docs/src/reference/models/layers.md
@@ -104,6 +104,7 @@ PairwiseFusion
 Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).
 
 ```@docs
+Recurrence
 RNNCell
 RNN
 LSTMCell
diff --git a/src/Flux.jl b/src/Flux.jl
index 8fb2351aa2..3a598e88f5 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -38,7 +38,7 @@ using EnzymeCore: EnzymeCore
 export Chain, Dense, Embedding, EmbeddingBag,
        Maxout, SkipConnection, Parallel, PairwiseFusion,
        RNNCell, LSTMCell, GRUCell, GRUv3Cell,
-       RNN, LSTM, GRU, GRUv3,
+       RNN, LSTM, GRU, GRUv3, Recurrence,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout,
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9386a3fc2d..a6a82b0e72 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,8 +1,7 @@
 out_from_state(state) = state
 out_from_state(state::Tuple) = state[1]
 
-function scan(cell, x, state0)
-    state = state0
+function scan(cell, x, state)
     y = []
     for x_t in eachslice(x, dims = 2)
         state = cell(x_t, state)
@@ -12,7 +11,67 @@ function scan(cell, x, state)
     return stack(y, dims = 2)
 end
 
+"""
+    Recurrence(cell)
+
+Create a recurrent layer that processes an entire sequence using a recurrent `cell`,
+such as an [`RNNCell`](@ref), [`LSTMCell`](@ref), or [`GRUCell`](@ref),
+similarly to how [`RNN`](@ref), [`LSTM`](@ref), and [`GRU`](@ref) process sequences.
+
+The `cell` should be a callable object that takes an input `x` and a hidden state `state` and returns
+a new hidden state `state'`. The `cell` should also implement the `initialstates` method that returns
+the initial hidden state. The output of the `cell` is considered to be:
+1. The first element of the `state` tuple if `state` is a tuple (e.g. `(h, c)` for LSTM).
+2. The `state` itself if `state` is not a tuple, e.g. an array `h` for RNN and GRU.
+
+# Forward
+
+    rnn(x, [state])
+
+The input `x` should be an array of size `in x len` or `in x len x batch_size`,
+where `in` is the input dimension of the cell, `len` is the sequence length, and `batch_size` is the batch size.
+The `state` should be a valid state for the recurrent cell. If not provided, it is obtained by calling
+`Flux.initialstates(cell)`.
+
+The output is an array of size `out x len x batch_size`, where `out` is the output dimension of the cell.
+
+The operation performed is semantically equivalent to the following code:
+```julia
+out_from_state(state) = state
+out_from_state(state::Tuple) = state[1]
+state = Flux.initialstates(cell)
+out = []
+for x_t in eachslice(x, dims = 2)
+    state = cell(x_t, state)
+    out = [out..., out_from_state(state)]
+end
+stack(out, dims = 2)
+```
+
+# Examples
+
+```jldoctest
+julia> rnn = Recurrence(RNNCell(2 => 3))
+Recurrence(
+  RNNCell(2 => 3, tanh),                # 18 parameters
+)                   # Total: 3 arrays, 18 parameters, 232 bytes.
+
+julia> x = rand(Float32, 2, 3, 4);  # in x len x batch_size
+
+julia> y = rnn(x);  # out x len x batch_size
+```
+"""
+struct Recurrence{M}
+    cell::M
+end
+
+@layer Recurrence
+
+initialstates(rnn::Recurrence) = initialstates(rnn.cell)
+
+(rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn))
+(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state)
 
 # Vanilla RNN
 
 @doc raw"""
@@ -184,9 +243,7 @@ julia> x = rand(Float32, (d_in, len, batch_size));
 
 julia> h = zeros(Float32, (d_out, batch_size));
 
 julia> rnn = RNN(d_in => d_out)
-RNN(
-  RNNCell(4 => 6, tanh),                # 66 parameters
-)                   # Total: 3 arrays, 66 parameters, 424 bytes.
+RNN(4 => 6, tanh)  # 66 parameters
 
 julia> y = rnn(x, h);  # [y] = [d_out, len, batch_size]
 ```
@@ -211,7 +268,7 @@ struct RNN{M}
     cell::M
 end
 
-@layer RNN
+@layer :noexpand RNN
 
 initialstates(rnn::RNN) = initialstates(rnn.cell)
 
@@ -232,6 +289,12 @@ function (m::RNN)(x::AbstractArray, h)
     return scan(m.cell, x, h)
 end
 
+function Base.show(io::IO, m::RNN)
+    print(io, "RNN(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1))
+    print(io, ", ", m.cell.σ)
+    print(io, ")")
+end
+
 # LSTM
 
 @doc raw"""
@@ -400,7 +463,7 @@ struct LSTM{M}
     cell::M
 end
 
-@layer LSTM
+@layer :noexpand LSTM
 
 initialstates(lstm::LSTM) = initialstates(lstm.cell)
 
@@ -416,6 +479,10 @@ function (m::LSTM)(x::AbstractArray, state0)
     return scan(m.cell, x, state0)
 end
 
+function Base.show(io::IO, m::LSTM)
+    print(io, "LSTM(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 4, ")")
+end
+
 # GRU
 
 @doc raw"""
@@ -568,7 +635,7 @@ struct GRU{M}
     cell::M
 end
 
-@layer GRU
+@layer :noexpand GRU
 
 initialstates(gru::GRU) = initialstates(gru.cell)
 
@@ -584,6 +651,10 @@ function (m::GRU)(x::AbstractArray, h)
     return scan(m.cell, x, h)
 end
 
+function Base.show(io::IO, m::GRU)
+    print(io, "GRU(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
+
 # GRU v3
 @doc raw"""
     GRUv3Cell(in => out; init_kernel = glorot_uniform,
@@ -728,7 +799,7 @@ struct GRUv3{M}
     cell::M
 end
 
-@layer GRUv3
+@layer :noexpand GRUv3
 
 initialstates(gru::GRUv3) = initialstates(gru.cell)
 
@@ -743,3 +814,7 @@ function (m::GRUv3)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
     return scan(m.cell, x, h)
 end
+
+function Base.show(io::IO, m::GRUv3)
+    print(io, "GRUv3(", size(m.cell.Wi, 2), " => ", size(m.cell.Wi, 1) ÷ 3, ")")
+end
\ No newline at end of file
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 864e5dad8e..3ad7428601 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -305,3 +305,12 @@ end
     # no initial state same as zero initial state
    @test gru(x) ≈ gru(x, zeros(Float32, 4))
 end
+
+@testset "Recurrence" begin
+    x = rand(Float32, 2, 3, 4)
+    for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)]
+        cell = rnn.cell
+        rec = Recurrence(cell)
+        @test rec(x) ≈ rnn(x)
+    end
+end
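
Note (not part of the patch): the docstring added in `src/layers/recurrent.jl` above spells out the interface a cell needs in order to work with the new `Recurrence` layer — a call `cell(x_t, state)` returning the next state, plus a `Flux.initialstates` method. Below is a minimal sketch of a hand-written cell wrapped by that layer, assuming this branch of Flux; the type `MinimalCell`, its fields, and its initialisation are hypothetical and only illustrate the interface described there.

```julia
using Flux

# Hypothetical cell, for illustration only -- not part of the diff above.
struct MinimalCell{W}
    Wi::W   # input-to-hidden weights, size out x in
    Wh::W   # hidden-to-hidden weights, size out x out
end

Flux.@layer MinimalCell

MinimalCell((in, out)::Pair{Int,Int}) =
    MinimalCell(Flux.glorot_uniform(out, in), Flux.glorot_uniform(out, out))

# Interface expected by `Recurrence`: an initial state for the first time step...
Flux.initialstates(c::MinimalCell) = zeros(Float32, size(c.Wh, 1))

# ...and a call taking one time step of input plus the previous state,
# returning the next state.
(c::MinimalCell)(x, h) = tanh.(c.Wi * x .+ c.Wh * h)

rnn = Recurrence(MinimalCell(2 => 3))
x = rand(Float32, 2, 5, 7)   # in x len x batch_size
y = rnn(x)                   # 3 x 5 x 7
```

Because the returned state is a plain array rather than a tuple, `out_from_state` passes it through unchanged, so each time step's output is simply the hidden state.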