From 40b7f700b65acf8fda0f51d6b17e2dfec5f7787f Mon Sep 17 00:00:00 2001
From: Francesco Martinuzzi
Date: Tue, 10 Dec 2024 07:06:36 +0100
Subject: [PATCH] Adding initialstates function to RNNs (#2541)

* added initialstates

* added initialstates to recurrent layers, added docstrings

* fixed small errors

* streamlined implementation, added tests

* Update docs/src/reference/models/layers.md

---------

Co-authored-by: Carlo Lucibello
---
 docs/src/reference/models/layers.md |   1 +
 src/Flux.jl                         |   2 +-
 src/layers/recurrent.jl             | 121 ++++++++++++++++++++++------
 test/layers/recurrent.jl            |  36 ++++++++-
 4 files changed, 131 insertions(+), 29 deletions(-)

diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md
index 8e5c0e873c..355d3e7833 100644
--- a/docs/src/reference/models/layers.md
+++ b/docs/src/reference/models/layers.md
@@ -112,6 +112,7 @@ GRUCell
 GRU
 GRUv3Cell
 GRUv3
+Flux.initialstates
 ```
 
 ## Normalisation & Regularisation
diff --git a/src/Flux.jl b/src/Flux.jl
index 272284ef46..8fb2351aa2 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -54,7 +54,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
        # layers
        Bilinear, Scale,
        # utils
-       outputsize, state, create_bias, @layer,
+       outputsize, state, create_bias, @layer, initialstates,
        # from OneHotArrays.jl
        onehot, onehotbatch, onecold,
        # from Train
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index dd5f84aa58..a170bb2d3d 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -32,7 +32,7 @@ The arguments of the forward pass are:
 
 - `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 # Examples
 
@@ -69,6 +69,29 @@ end
 
 @layer RNNCell
 
+"""
+    initialstates(rnn) -> AbstractVector
+
+Return the initial hidden state for the given recurrent cell or recurrent layer.
+
+# Example
+```julia
+using Flux
+
+# Create an RNNCell from input dimension 10 to output dimension 20
+rnn = RNNCell(10 => 20)
+
+# Get the initial hidden state
+h0 = initialstates(rnn)
+
+# Get some input data
+x = rand(Float32, 10)
+
+# Run forward
+res = rnn(x, h0)
+```
+"""
+initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
+
 function RNNCell(
     (in, out)::Pair,
     σ = tanh;
@@ -82,7 +105,10 @@ function RNNCell(
     return RNNCell(σ, Wi, Wh, b)
 end
 
-(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
+function (rnn::RNNCell)(x::AbstractVecOrMat)
+    state = initialstates(rnn)
+    return rnn(x, state)
+end
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi, 2))
@@ -130,7 +156,7 @@ The arguments of the forward pass are:
 
 - `x`: The input to the RNN. It should be a matrix size `in x len` or an array of size `in x len x batch_size`.
 - `h`: The initial hidden state of the RNN. If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
 
@@ -173,12 +199,17 @@ end
 
 @layer RNN
 
+initialstates(rnn::RNN) = initialstates(rnn.cell)
+
 function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
     cell = RNNCell(in => out, σ; cell_kwargs...)
     return RNN(cell)
 end
 
-(m::RNN)(x::AbstractArray) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
+function (rnn::RNN)(x::AbstractArray)
+    state = initialstates(rnn)
+    return rnn(x, state)
+end
 
 function (m::RNN)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
@@ -231,7 +262,7 @@ The arguments of the forward pass are:
 
 - `x`: The input to the LSTM. It should be a matrix of size `in` or an array of size `in x batch_size`.
 - `(h, c)`: A tuple containing the hidden and cell states of the LSTM. They should be vectors of size `out` or matrices of size `out x batch_size`.
-  If not provided, they are assumed to be vectors of zeros.
+  If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
 
 Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
 
@@ -261,6 +292,10 @@ end
 
 @layer LSTMCell
 
+function initialstates(lstm::LSTMCell)
+    return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))
+end
+
 function LSTMCell(
     (in, out)::Pair;
     init_kernel = glorot_uniform,
@@ -274,10 +309,9 @@ function LSTMCell(
     return cell
 end
 
-function (m::LSTMCell)(x::AbstractVecOrMat)
-    h = zeros_like(x, size(m.Wh, 2))
-    c = zeros_like(h)
-    return m(x, (h, c))
+function (lstm::LSTMCell)(x::AbstractVecOrMat)
+    state, cstate = initialstates(lstm)
+    return lstm(x, (state, cstate))
 end
 
 function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
@@ -332,7 +366,7 @@ The arguments of the forward pass are:
 
 - `x`: The input to the LSTM. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
 - `(h, c)`: A tuple containing the hidden and cell states of the LSTM. They should be vectors of size `out` or matrices of size `out x batch_size`.
-  If not provided, they are assumed to be vectors of zeros.
+  If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
 
 Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t` in tensors of size `out x len` or `out x len x batch_size`.
 
@@ -363,15 +397,16 @@ end
 
 @layer LSTM
 
+initialstates(lstm::LSTM) = initialstates(lstm.cell)
+
 function LSTM((in, out)::Pair; cell_kwargs...)
     cell = LSTMCell(in => out; cell_kwargs...)
     return LSTM(cell)
 end
 
-function (m::LSTM)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    c = zeros_like(h)
-    return m(x, (h, c))
+function (lstm::LSTM)(x::AbstractArray)
+    state, cstate = initialstates(lstm)
+    return lstm(x, (state, cstate))
 end
 
 function (m::LSTM)(x::AbstractArray, (h, c))
@@ -422,7 +457,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.
 The arguments of the forward pass are:
 
 - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
@@ -447,6 +482,8 @@ end
 
 @layer GRUCell
 
+initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))
+
 function GRUCell(
     (in, out)::Pair;
     init_kernel = glorot_uniform,
@@ -459,7 +496,10 @@ function GRUCell(
     return GRUCell(Wi, Wh, b)
 end
 
-(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+function (gru::GRUCell)(x::AbstractVecOrMat)
+    state = initialstates(gru)
+    return gru(x, state)
+end
 
 function (m::GRUCell)(x::AbstractVecOrMat, h)
     _size_check(m, x, 1 => size(m.Wi, 2))
@@ -514,7 +554,7 @@ The arguments of the forward pass are:
 
 - `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
 - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
 
@@ -534,14 +574,16 @@ end
 
 @layer GRU
 
+initialstates(gru::GRU) = initialstates(gru.cell)
+
 function GRU((in, out)::Pair; cell_kwargs...)
     cell = GRUCell(in => out; cell_kwargs...)
     return GRU(cell)
 end
 
-function (m::GRU)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    return m(x, h)
+function (gru::GRU)(x::AbstractArray)
+    state = initialstates(gru)
+    return gru(x, state)
 end
 
 function (m::GRU)(x::AbstractArray, h)
@@ -590,7 +632,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
 The arguments of the forward pass are:
 
 - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
 Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
 """
@@ -603,6 +645,8 @@ end
 
 @layer GRUv3Cell
 
+initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2))
+
 function GRUv3Cell(
     (in, out)::Pair;
     init_kernel = glorot_uniform,
@@ -616,7 +660,10 @@ function GRUv3Cell(
     return GRUv3Cell(Wi, Wh, b, Wh_h̃)
 end
 
-(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+function (gru::GRUv3Cell)(x::AbstractVecOrMat)
+    state = initialstates(gru)
+    return gru(x, state)
+end
 
 function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
     _size_check(m, x, 1 => size(m.Wi, 2))
@@ -667,6 +714,28 @@ but only a less popular variant.
 - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
 - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
+
+# Forward
+
+    gruv3(x, [h])
+
+The arguments of the forward pass are:
+
+- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
+- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
+
+Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
+
+# Examples
+
+```julia
+d_in, d_out, len, batch_size = 2, 3, 4, 5
+gruv3 = GRUv3(d_in => d_out)
+x = rand(Float32, (d_in, len, batch_size))
+h0 = zeros(Float32, d_out)
+h = gruv3(x, h0)  # out x len x batch_size
+```
 """
 struct GRUv3{M}
     cell::M
 end
 
@@ -674,14 +743,16 @@ end
 
 @layer GRUv3
 
+initialstates(gru::GRUv3) = initialstates(gru.cell)
+
 function GRUv3((in, out)::Pair; cell_kwargs...)
     cell = GRUv3Cell(in => out; cell_kwargs...)
     return GRUv3(cell)
 end
 
-function (m::GRUv3)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    return m(x, h)
+function (gru::GRUv3)(x::AbstractArray)
+    state = initialstates(gru)
+    return gru(x, state)
 end
 
 function (m::GRUv3)(x::AbstractArray, h)
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index f882cdccc2..f4f4777fd2 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -43,6 +43,9 @@
     test_gradients(r, x, h, loss=loss3) # splat
     test_gradients(r, x, h, loss=loss4) # vcat and stack
 
+    # initial states are zero
+    @test Flux.initialstates(r) ≈ zeros(Float32, 5)
+
     # no initial state same as zero initial state
     @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
 
@@ -80,8 +83,11 @@ end
     @test size(y) == (4, 3, 1)
     test_gradients(model, x)
 
+    rnn = model.rnn
+    # initial states are zero
+    @test Flux.initialstates(rnn) ≈ zeros(Float32, 4)
+
     # no initial state same as zero initial state
-    rnn = model.rnn
     @test rnn(x) ≈ rnn(x, zeros(Float32, 4))
 
     x = rand(Float32, 2, 3)
@@ -120,6 +126,11 @@ end
     test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
     test_gradients(cell, x, (h, c), loss = loss)
 
+    # initial states are zero
+    h0, c0 = Flux.initialstates(cell)
+    @test h0 ≈ zeros(Float32, 5)
+    @test c0 ≈ zeros(Float32, 5)
+
     # no initial state same as zero initial state
     hnew1, cnew1 = cell(x[1])
     hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
@@ -166,6 +177,12 @@ end
     @test size(h) == (4, 3)
     @test c isa Array{Float32, 2}
     @test size(c) == (4, 3)
+
+    # initial states are zero
+    h0, c0 = Flux.initialstates(lstm)
+    @test h0 ≈ zeros(Float32, 4)
+    @test c0 ≈ zeros(Float32, 4)
+
     # no initial state same as zero initial state
     h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
     @test h ≈ h1
@@ -192,6 +209,9 @@ end
     h = randn(Float32, 5)
     test_gradients(r, x, h; loss)
 
+    # initial states are zero
+    @test Flux.initialstates(r) ≈ zeros(Float32, 5)
+
     # no initial state same as zero initial state
     @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
 
@@ -227,8 +247,12 @@ end
     @test size(y) == (4, 3, 1)
     test_gradients(model, x)
 
-    # no initial state same as zero initial state
     gru = model.gru
+
+    # initial states are zero
+    @test Flux.initialstates(gru) ≈ zeros(Float32, 4)
+
+    # no initial state same as zero initial state
     @test gru(x) ≈ gru(x, zeros(Float32, 4))
 
     # No Bias
@@ -246,6 +270,9 @@ end
     h = randn(Float32, 5)
     test_gradients(r, x, h)
 
+    # initial states are zero
+    @test Flux.initialstates(r) ≈ zeros(Float32, 5)
+
     # no initial state same as zero initial state
     @test r(x) ≈ r(x, zeros(Float32, 5))
 
@@ -277,7 +304,10 @@ end
     @test size(y) == (4, 3, 1)
     test_gradients(model, x)
 
-    # no initial state same as zero initial state
     gru = model.gru
+    # initial states are zero
+    @test Flux.initialstates(gru) ≈ zeros(Float32, 4)
+
+    # no initial state same as zero initial state
     @test gru(x) ≈ gru(x, zeros(Float32, 4))
 end
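---

For orientation, here is a minimal usage sketch of the API this patch introduces. It is not part of the diff itself; it assumes a Flux build with this change applied, and the layer sizes and data are arbitrary.

```julia
using Flux

# With this patch, Flux.initialstates returns the zeroed starting state;
# for an LSTM layer that is an (h, c) tuple sized from the cell's Wh matrix.
lstm = LSTM(3 => 5)
h0, c0 = Flux.initialstates(lstm)   # two 5-element zero vectors

x = rand(Float32, 3, 10, 16)        # in x len x batch_size

# Calling the layer without a state now routes through initialstates,
# so it matches passing the zero state explicitly.
h1, c1 = lstm(x)
h2, c2 = lstm(x, (h0, c0))
@assert h1 ≈ h2 && c1 ≈ c2
```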