From 9e1b1bb0ecc223dce6a79bafbbf9cb1d9ab21351 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Thu, 5 Dec 2024 15:18:37 +0100 Subject: [PATCH 1/5] added initialstates --- src/layers/recurrent.jl | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 750141db1b..bee21260a2 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -69,6 +69,8 @@ end @layer RNNCell +initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2)) + function RNNCell( (in, out)::Pair, σ = tanh; @@ -82,7 +84,10 @@ function RNNCell( return RNNCell(σ, Wi, Wh, b) end -(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1))) +function (rnn::RNNCell)(x::AbstractVecOrMat) + state = initialstates(rnn) + rnn(x, state) +end function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat) _size_check(m, x, 1 => size(m.Wi, 2)) @@ -261,6 +266,10 @@ end @layer LSTMCell +function initialstates(lstm:: LSTMCell) + return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2)) +end + function LSTMCell( (in, out)::Pair; init_kernel = glorot_uniform, @@ -274,10 +283,9 @@ function LSTMCell( return cell end -function (m::LSTMCell)(x::AbstractVecOrMat) - h = zeros_like(x, size(m.Wh, 2)) - c = zeros_like(h) - return m(x, (h, c)) +function (lstm::LSTMCell)(x::AbstractVecOrMat) + state, cstate = initialstates(lstm) + return lstm(x, (state, cstate)) end function (m::LSTMCell)(x::AbstractVecOrMat, (h, c)) @@ -447,6 +455,8 @@ end @layer GRUCell +initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2)) + function GRUCell( (in, out)::Pair; init_kernel = glorot_uniform, @@ -459,7 +469,10 @@ function GRUCell( return GRUCell(Wi, Wh, b) end -(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2))) +function (gru::GRUCell)(x::AbstractVecOrMat) + state = initialstates(gru) + return gru(x, state) +end function (m::GRUCell)(x::AbstractVecOrMat, h) _size_check(m, x, 1 => size(m.Wi, 2)) @@ -603,6 +616,8 @@ end @layer GRUv3Cell +initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2)) + function GRUv3Cell( (in, out)::Pair; init_kernel = glorot_uniform, @@ -616,7 +631,10 @@ function GRUv3Cell( return GRUv3Cell(Wi, Wh, b, Wh_h̃) end -(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2))) +function (gru::GRUv3Cell)(x::AbstractVecOrMat) + state = initialstates(gru) + return gru(x, state) +end function (m::GRUv3Cell)(x::AbstractVecOrMat, h) _size_check(m, x, 1 => size(m.Wi, 2)) From a9bc95f13a476982f766a4288c2e9991ee32a6cc Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 7 Dec 2024 21:20:28 +0100 Subject: [PATCH 2/5] added initialstates to recurrent layers, added docstrings --- docs/src/reference/models/layers.md | 1 + src/Flux.jl | 2 +- src/layers/recurrent.jl | 71 ++++++++++++++++++++++++----- 3 files changed, 61 insertions(+), 13 deletions(-) diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md index b798a35291..a98b942755 100644 --- a/docs/src/reference/models/layers.md +++ b/docs/src/reference/models/layers.md @@ -117,6 +117,7 @@ GRUCell GRU GRUv3Cell GRUv3 +initialstates ``` ## Normalisation & Regularisation diff --git a/src/Flux.jl b/src/Flux.jl index 272284ef46..8fb2351aa2 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -54,7 +54,7 @@ export Chain, Dense, Embedding, EmbeddingBag, # layers Bilinear, Scale, # utils - outputsize, state, create_bias, @layer, + outputsize, state, create_bias, @layer, 
initialstates,
 # from OneHotArrays.jl
 onehot, onehotbatch, onecold,
 # from Train
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index bee21260a2..34a2d24519 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -69,6 +69,39 @@ end
 
 @layer RNNCell
 
+"""
+    initialstates(rnn) -> AbstractVector
+
+Return the initial hidden state for the given cell or recurrent layer.
+The returned vector is initialized to zeros and has the appropriate
+dimension inferred from the cell's internal recurrent weight matrix.
+
+# Arguments
+- `rnn`: The recurrent neural network cell or recurrent layer for
+  which the initial state vector is requested. It can be any of
+  `RNNCell`, `RNN`, `LSTMCell`, `LSTM`, `GRUCell`, `GRU`,
+  `GRUv3Cell`, and `GRUv3`
+
+# Returns
+An `AbstractVector` of zeros representing the initial hidden state, whose length
+matches the output dimension of the cell or recurrent layer.
+
+# Example
+```julia
+using Flux
+
+# Create an RNNCell from input dimension 10 to output dimension 20
+rnn = RNNCell(10 => 20)
+
+# Get the initial hidden state
+h0 = initialstates(rnn)
+
+# Get some input data
+x = rand(Float32, 10)
+
+# Run forward
+res = rnn(x, h0)
+```
+"""
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
 
 function RNNCell(
@@ -86,7 +119,7 @@ end
 
 function (rnn::RNNCell)(x::AbstractVecOrMat)
     state = initialstates(rnn)
-    rnn(x, state)
+    return rnn(x, state)
 end
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
@@ -178,12 +211,17 @@ end
 
 @layer RNN
 
+initialstates(rnn::RNN) = zeros_like(x, size(rnn.cell.Wh, 1))
+
 function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
     cell = RNNCell(in => out, σ; cell_kwargs...)
     return RNN(cell)
 end
 
-(m::RNN)(x::AbstractArray) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
+function (m::RNN)(x::AbstractArray)
+    state = initialstates(rnn)
+    return rnn(x, state)
+end
 
 function (m::RNN)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
@@ -371,15 +409,20 @@ end
 
 @layer LSTM
 
+function initialstates(lstm::LSTM)
+    state = zeros_like(x, size(lstm.cell.Wh, 2))
+    cstate = zeros_like(state)
+    return state, cstate
+end
+
 function LSTM((in, out)::Pair; cell_kwargs...)
     cell = LSTMCell(in => out; cell_kwargs...)
     return LSTM(cell)
 end
 
-function (m::LSTM)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    c = zeros_like(h)
-    return m(x, (h, c))
+function (lstm::LSTM)(x::AbstractArray)
+    state, cstate = initialstates(lstm)
+    return lstm(x, (state, cstate))
 end
 
 function (m::LSTM)(x::AbstractArray, (h, c))
@@ -547,14 +590,16 @@ end
 
 @layer GRU
 
+initialstates(gru::GRU) = zeros_like(x, size(gru.cell.Wh, 2))
+
 function GRU((in, out)::Pair; cell_kwargs...)
     cell = GRUCell(in => out; cell_kwargs...)
     return GRU(cell)
 end
 
-function (m::GRU)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    return m(x, h)
+function (gru::GRU)(x::AbstractArray)
+    state = initialstates(gru)
+    return gru(x, state)
 end
 
 function (m::GRU)(x::AbstractArray, h)
@@ -692,14 +737,16 @@ end
 
 @layer GRUv3
 
+initialstates(gru::GRUv3) = zeros_like(x, size(gru.cell.Wh, 2))
+
 function GRUv3((in, out)::Pair; cell_kwargs...)
     cell = GRUv3Cell(in => out; cell_kwargs...)
return GRUv3(cell) end -function (m::GRUv3)(x::AbstractArray) - h = zeros_like(x, size(m.cell.Wh, 2)) - return m(x, h) +function (gru::GRUv3)(x::AbstractArray) + state = initialstates(gru) + return gru(x, state) end function (m::GRUv3)(x::AbstractArray, h) From d54dcd40eac32103a293e21be976278b3f569b99 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Sat, 7 Dec 2024 21:37:00 +0100 Subject: [PATCH 3/5] fixed small errors --- src/layers/recurrent.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 34a2d24519..c74cea4daa 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -211,14 +211,14 @@ end @layer RNN -initialstates(rnn::RNN) = zeros_like(x, size(rnn.cell.Wh, 1)) +initialstates(rnn::RNN) = zeros_like(rnn.cell.Wh, size(rnn.cell.Wh, 1)) function RNN((in, out)::Pair, σ = tanh; cell_kwargs...) cell = RNNCell(in => out, σ; cell_kwargs...) return RNN(cell) end -function (m::RNN)(x::AbstractArray) +function (rnn::RNN)(x::AbstractArray) state = initialstates(rnn) return rnn(x, state) end @@ -410,7 +410,7 @@ end @layer LSTM function initialstates(lstm::LSTM) - state = zeros_like(x, size(lstm.cell.Wh, 2)) + state = zeros_like(lstm.cell.Wh, size(lstm.cell.Wh, 2)) cstate = zeros_like(state) return state, cstate end @@ -590,7 +590,7 @@ end @layer GRU -initialstates(gru::GRU) = zeros_like(x, size(gru.cell.Wh, 2)) +initialstates(gru::GRU) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2)) function GRU((in, out)::Pair; cell_kwargs...) cell = GRUCell(in => out; cell_kwargs...) @@ -737,7 +737,7 @@ end @layer GRUv3 -initialstates(gru::GRUv3) = zeros_like(x, size(gru.cell.Wh, 2)) +initialstates(gru::GRUv3) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2)) function GRUv3((in, out)::Pair; cell_kwargs...) cell = GRUv3Cell(in => out; cell_kwargs...) From ab25eeed796f9da6fc845c891d4184b26e0b8e14 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Mon, 9 Dec 2024 22:48:25 +0100 Subject: [PATCH 4/5] streamlined implementation, added tests --- src/layers/recurrent.jl | 62 ++++++++++++++++++++++------------------ test/layers/recurrent.jl | 36 +++++++++++++++++++++-- 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index c74cea4daa..6a49b1abae 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -32,7 +32,7 @@ The arguments of the forward pass are: - `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`. - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`. - If not provided, it is assumed to be a vector of zeros. + If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). # Examples @@ -72,19 +72,7 @@ end """ initialstates(rnn) -> AbstractVector -Return the initial hidden state for the given cell or recurrent layer. -The returned vector is initialized to zeros and has the appropriate -dimension inferred from the cell's internal recurrent weight matrix. - -# Arguments -- `rnn`: The recurrent neural network cell or recurrent layer for - which the initial state vector is requested. It can be any of - `RNNCell`, `RNN`, `LSTMCell`, `LSTM`, `GRUCell`, `GRU`, - `GRUv3Cell`, and `GRUv3` - -# Returns -An `AbstractVector` of zeros representing the initial hidden state, whose length -matches the output dimension of the cell or recurrent layer. 
+Return the initial hidden state for the given recurrent cell or recurrent layer. # Example ```julia @@ -168,7 +156,7 @@ The arguments of the forward pass are: - `x`: The input to the RNN. It should be a matrix size `in x len` or an array of size `in x len x batch_size`. - `h`: The initial hidden state of the RNN. If given, it is a vector of size `out` or a matrix of size `out x batch_size`. - If not provided, it is assumed to be a vector of zeros. + If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. @@ -211,7 +199,7 @@ end @layer RNN -initialstates(rnn::RNN) = zeros_like(rnn.cell.Wh, size(rnn.cell.Wh, 1)) +initialstates(rnn::RNN) = initialstates(rnn.cell) function RNN((in, out)::Pair, σ = tanh; cell_kwargs...) cell = RNNCell(in => out, σ; cell_kwargs...) @@ -274,7 +262,7 @@ The arguments of the forward pass are: - `x`: The input to the LSTM. It should be a matrix of size `in` or an array of size `in x batch_size`. - `(h, c)`: A tuple containing the hidden and cell states of the LSTM. They should be vectors of size `out` or matrices of size `out x batch_size`. - If not provided, they are assumed to be vectors of zeros. + If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref). Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`. @@ -378,7 +366,7 @@ The arguments of the forward pass are: - `x`: The input to the LSTM. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`. - `(h, c)`: A tuple containing the hidden and cell states of the LSTM. They should be vectors of size `out` or matrices of size `out x batch_size`. - If not provided, they are assumed to be vectors of zeros. + If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref). Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t` in tensors of size `out x len` or `out x len x batch_size`. @@ -409,11 +397,7 @@ end @layer LSTM -function initialstates(lstm::LSTM) - state = zeros_like(lstm.cell.Wh, size(lstm.cell.Wh, 2)) - cstate = zeros_like(state) - return state, cstate -end +initialstates(lstm::LSTM) = initialstates(lstm.cell) function LSTM((in, out)::Pair; cell_kwargs...) cell = LSTMCell(in => out; cell_kwargs...) @@ -473,7 +457,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences. The arguments of the forward pass are: - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`. - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. - If not provided, it is assumed to be a vector of zeros. + If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`. @@ -570,7 +554,7 @@ The arguments of the forward pass are: - `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`. - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. - If not provided, it is assumed to be a vector of zeros. + If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). 
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. @@ -590,7 +574,7 @@ end @layer GRU -initialstates(gru::GRU) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2)) +initialstates(gru::GRU) = initialstates(gru.cell) function GRU((in, out)::Pair; cell_kwargs...) cell = GRUCell(in => out; cell_kwargs...) @@ -648,7 +632,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer. The arguments of the forward pass are: - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`. - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. - If not provided, it is assumed to be a vector of zeros. + If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`. """ @@ -730,6 +714,28 @@ but only a less popular variant. - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`. - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`. - `bias`: Whether to include a bias term initialized to zero. Default is `true`. + +# Forward + + gruv3(x, [h]) + +The arguments of the forward pass are: + +- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`. +- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. + If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). + +Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. + +# Examples + +```julia +d_in, d_out, len, batch_size = 2, 3, 4, 5 +gruv3 = GRUv3(d_in => d_out) +x = rand(Float32, (d_in, len, batch_size)) +h0 = zeros(Float32, d_out) +h = gruv3(x, h0) # out x len x batch_size +``` """ struct GRUv3{M} cell::M @@ -737,7 +743,7 @@ end @layer GRUv3 -initialstates(gru::GRUv3) = zeros_like(gru.cell.Wh, size(gru.cell.Wh, 2)) +initialstates(gru::GRUv3) = initialstates(gru.cell) function GRUv3((in, out)::Pair; cell_kwargs...) cell = GRUv3Cell(in => out; cell_kwargs...) 
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index f882cdccc2..f4f4777fd2 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -43,6 +43,9 @@ test_gradients(r, x, h, loss=loss3) # splat test_gradients(r, x, h, loss=loss4) # vcat and stack + # initial states are zero + @test Flux.initialstates(r) ≈ zeros(Float32, 5) + # no initial state same as zero initial state @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) @@ -80,8 +83,11 @@ end @test size(y) == (4, 3, 1) test_gradients(model, x) + rnn = model.rnn + # initial states are zero + @test Flux.initialstates(rnn) ≈ zeros(Float32, 4) + # no initial state same as zero initial state - rnn = model.rnn @test rnn(x) ≈ rnn(x, zeros(Float32, 4)) x = rand(Float32, 2, 3) @@ -120,6 +126,11 @@ end test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1])) test_gradients(cell, x, (h, c), loss = loss) + # initial states are zero + h0, c0 = Flux.initialstates(cell) + @test h0 ≈ zeros(Float32, 5) + @test c0 ≈ zeros(Float32, 5) + # no initial state same as zero initial state hnew1, cnew1 = cell(x[1]) hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5))) @@ -166,6 +177,12 @@ end @test size(h) == (4, 3) @test c isa Array{Float32, 2} @test size(c) == (4, 3) + + # initial states are zero + h0, c0 = Flux.initialstates(lstm) + @test h0 ≈ zeros(Float32, 4) + @test c0 ≈ zeros(Float32, 4) + # no initial state same as zero initial state h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4))) @test h ≈ h1 @@ -192,6 +209,9 @@ end h = randn(Float32, 5) test_gradients(r, x, h; loss) + # initial states are zero + @test Flux.initialstates(r) ≈ zeros(Float32, 5) + # no initial state same as zero initial state @test r(x[1]) ≈ r(x[1], zeros(Float32, 5)) @@ -227,8 +247,12 @@ end @test size(y) == (4, 3, 1) test_gradients(model, x) - # no initial state same as zero initial state + gru = model.gru + # initial states are zero + @test Flux.initialstates(gru) ≈ zeros(Float32, 4) + + # no initial state same as zero initial state @test gru(x) ≈ gru(x, zeros(Float32, 4)) # No Bias @@ -246,6 +270,9 @@ end h = randn(Float32, 5) test_gradients(r, x, h) + # initial states are zero + @test Flux.initialstates(r) ≈ zeros(Float32, 5) + # no initial state same as zero initial state @test r(x) ≈ r(x, zeros(Float32, 5)) @@ -277,7 +304,10 @@ end @test size(y) == (4, 3, 1) test_gradients(model, x) - # no initial state same as zero initial state gru = model.gru + # initial states are zero + @test Flux.initialstates(gru) ≈ zeros(Float32, 4) + + # no initial state same as zero initial state @test gru(x) ≈ gru(x, zeros(Float32, 4)) end From 48402d7a3ea9ad7de761458684c36959f5f733b4 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 9 Dec 2024 23:15:17 +0100 Subject: [PATCH 5/5] Update docs/src/reference/models/layers.md --- docs/src/reference/models/layers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/reference/models/layers.md b/docs/src/reference/models/layers.md index a98b942755..b89a4171fc 100644 --- a/docs/src/reference/models/layers.md +++ b/docs/src/reference/models/layers.md @@ -117,7 +117,7 @@ GRUCell GRU GRUv3Cell GRUv3 -initialstates +Flux.initialstates ``` ## Normalisation & Regularisation
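Taken together, the series leaves every recurrent cell and wrapper layer with a uniform zero-state entry point. The sketch below exercises that API as it stands after PATCH 5/5; it is illustrative only, not part of any patch, and the dimensions and data are arbitrary:

```julia
using Flux

# Cells: initialstates returns a zero state sized from the recurrent
# weight matrix Wh, so it works before the cell has seen any input.
rnn = RNNCell(10 => 20)
h0 = Flux.initialstates(rnn)       # 20-element zero vector
x = rand(Float32, 10)
h1 = rnn(x, h0)                    # same as rnn(x), which now calls
                                   # initialstates internally

# LSTMCell returns a (hidden, cell) tuple instead of a single vector.
lstm = LSTMCell(10 => 20)
state, cstate = Flux.initialstates(lstm)
h2, c2 = lstm(x, (state, cstate))  # same as lstm(x)

# Wrapper layers delegate to their cell (PATCH 4/5), so layer and
# cell always agree on the state shape.
gru = GRU(10 => 20)
@assert Flux.initialstates(gru) == Flux.initialstates(gru.cell)

# A custom initial state can still be passed explicitly.
xs = rand(Float32, 10, 7)          # in × len
hs = gru(xs, rand(Float32, 20))    # out × len
```

Routing the argument-free forward methods through `initialstates` keeps the zero-state convention in one place per cell type, next to the weight matrix its shape is derived from, rather than repeating `zeros_like` calls at every call site.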