Adding initialstates function to RNNs (#2541)
* added initialstates

* added initialstates to recurrent layers, added docstrings

* fixed small errors

* streamlined implementation, added tests

* Update docs/src/reference/models/layers.md

---------

Co-authored-by: Carlo Lucibello <[email protected]>
MartinuzziFrancesco and CarloLucibello authored Dec 10, 2024
1 parent 8c60006 commit 40b7f70
Showing 4 changed files with 131 additions and 29 deletions.
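Before the per-file diffs, here is a minimal usage sketch of the API this commit introduces. It is based only on what the diff below adds; shapes follow the docstrings shown there.

```julia
using Flux

# Recurrent cells and layers now expose `initialstates`, which returns
# an all-zero state sized to the hidden dimension (here 5).
rnn = RNN(3 => 5)
h0 = initialstates(rnn)     # 5-element Vector{Float32} of zeros

x = rand(Float32, 3, 10)    # in x len
y = rnn(x, h0)              # identical to rnn(x)

# For LSTMs the state is a (hidden, cell) tuple of zero vectors.
h0, c0 = initialstates(LSTM(3 => 5))
```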
1 change: 1 addition & 0 deletions docs/src/reference/models/layers.md
@@ -112,6 +112,7 @@ GRUCell
GRU
GRUv3Cell
GRUv3
Flux.initialstates
```

## Normalisation & Regularisation
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -54,7 +54,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
# layers
Bilinear, Scale,
# utils
outputsize, state, create_bias, @layer,
outputsize, state, create_bias, @layer, initialstates,
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
# from Train
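Since `initialstates` joins the export list, the unqualified name resolves after `using Flux`. A quick check, assuming this commit is applied:

```julia
using Flux

h0 = initialstates(GRUCell(3 => 7))  # same binding as Flux.initialstates
size(h0)                             # (7,)
```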
121 changes: 96 additions & 25 deletions src/layers/recurrent.jl
@@ -32,7 +32,7 @@ The arguments of the forward pass are:
- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
# Examples
@@ -69,6 +69,29 @@ end

@layer RNNCell

"""
    initialstates(rnn) -> AbstractVector

Return the initial hidden state for the given recurrent cell or recurrent layer.

# Example
```julia
using Flux

# Create an RNNCell from input dimension 10 to output dimension 20
rnn = RNNCell(10 => 20)

# Get the initial hidden state
h0 = initialstates(rnn)

# Get some input data
x = rand(Float32, 10)

# Run forward
res = rnn(x, h0)
```
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

function RNNCell(
(in, out)::Pair,
σ = tanh;
@@ -82,7 +105,10 @@ function RNNCell(
return RNNCell(σ, Wi, Wh, b)
end

(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
function (rnn::RNNCell)(x::AbstractVecOrMat)
state = initialstates(rnn)
return rnn(x, state)
end

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -130,7 +156,7 @@ The arguments of the forward pass are:
- `x`: The input to the RNN. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
- `h`: The initial hidden state of the RNN.
If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
@@ -173,12 +199,17 @@ end

@layer RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
return RNN(cell)
end

(m::RNN)(x::AbstractArray) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
function (rnn::RNN)(x::AbstractArray)
state = initialstates(rnn)
return rnn(x, state)
end

function (m::RNN)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
@@ -231,7 +262,7 @@ The arguments of the forward pass are:
- `x`: The input to the LSTM. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
@@ -261,6 +292,10 @@ end

@layer LSTMCell

function initialstates(lstm::LSTMCell)
return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))
end

function LSTMCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
@@ -274,10 +309,9 @@ function LSTMCell(
return cell
end

function (m::LSTMCell)(x::AbstractVecOrMat)
h = zeros_like(x, size(m.Wh, 2))
c = zeros_like(h)
return m(x, (h, c))
function (lstm::LSTMCell)(x::AbstractVecOrMat)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
@@ -332,7 +366,7 @@ The arguments of the forward pass are:
- `x`: The input to the LSTM. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
in tensors of size `out x len` or `out x len x batch_size`.
@@ -363,15 +397,16 @@ end

@layer LSTM

initialstates(lstm::LSTM) = initialstates(lstm.cell)

function LSTM((in, out)::Pair; cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
return LSTM(cell)
end

function (m::LSTM)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
c = zeros_like(h)
return m(x, (h, c))
function (lstm::LSTM)(x::AbstractArray)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end

function (m::LSTM)(x::AbstractArray, (h, c))
@@ -422,7 +457,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.
The arguments of the forward pass are:
- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
@@ -447,6 +482,8 @@ end

@layer GRUCell

initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
@@ -459,7 +496,10 @@ function GRUCell(
return GRUCell(Wi, Wh, b)
end

(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
function (gru::GRUCell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRUCell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -514,7 +554,7 @@ The arguments of the forward pass are:
- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
@@ -534,14 +574,16 @@ end

@layer GRU

initialstates(gru::GRU) = initialstates(gru.cell)

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
return GRU(cell)
end

function (m::GRU)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
return m(x, h)
function (gru::GRU)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRU)(x::AbstractArray, h)
@@ -590,7 +632,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
The arguments of the forward pass are:
- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
"""
@@ -603,6 +645,8 @@ end

@layer GRUv3Cell

initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUv3Cell(
(in, out)::Pair;
init_kernel = glorot_uniform,
@@ -616,7 +660,10 @@ function GRUv3Cell(
return GRUv3Cell(Wi, Wh, b, Wh_h̃)
end

(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
function (gru::GRUv3Cell)(x::AbstractVecOrMat)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -667,21 +714,45 @@ but only a less popular variant.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
# Forward
gruv3(x, [h])
The arguments of the forward pass are:
- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
# Examples
```julia
d_in, d_out, len, batch_size = 2, 3, 4, 5
gruv3 = GRUv3(d_in => d_out)
x = rand(Float32, (d_in, len, batch_size))
h0 = zeros(Float32, d_out)
h = gruv3(x, h0) # out x len x batch_size
```
"""
struct GRUv3{M}
cell::M
end

@layer GRUv3

initialstates(gru::GRUv3) = initialstates(gru.cell)

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
return GRUv3(cell)
end

function (m::GRUv3)(x::AbstractArray)
h = zeros_like(x, size(m.cell.Wh, 2))
return m(x, h)
function (gru::GRUv3)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end

function (m::GRUv3)(x::AbstractArray, h)
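Every stateless forward call in this file now follows the same two-line pattern: fetch a zero state from `initialstates`, then delegate to the stateful method. For illustration, a hypothetical custom cell (`MyCell` is not part of this commit) could adopt the same convention; a minimal sketch:

```julia
using Flux

# A hypothetical cell following the commit's convention: define
# `initialstates`, then let the stateless call delegate to it.
struct MyCell{W}
    Wi::W
    Wh::W
end

Flux.@layer MyCell

MyCell((in, out)::Pair) =
    MyCell(Flux.glorot_uniform(out, in), Flux.glorot_uniform(out, out))

# The built-in cells size the state from the recurrent weight via
# `zeros_like(Wh, size(Wh, 2))`; plain `zeros` keeps this sketch simple.
Flux.initialstates(m::MyCell) = zeros(Float32, size(m.Wh, 2))

(m::MyCell)(x::AbstractVecOrMat) = m(x, Flux.initialstates(m))
(m::MyCell)(x::AbstractVecOrMat, h) = tanh.(m.Wi * x .+ m.Wh * h)
```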
36 changes: 33 additions & 3 deletions test/layers/recurrent.jl
@@ -43,6 +43,9 @@
test_gradients(r, x, h, loss=loss3) # splat
test_gradients(r, x, h, loss=loss4) # vcat and stack

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)

# no initial state same as zero initial state
@test r(x[1]) ≈ r(x[1], zeros(Float32, 5))

@@ -80,8 +83,11 @@ end
@test size(y) == (4, 3, 1)
test_gradients(model, x)

rnn = model.rnn
# initial states are zero
@test Flux.initialstates(rnn) ≈ zeros(Float32, 4)

# no initial state same as zero initial state
rnn = model.rnn
@test rnn(x) ≈ rnn(x, zeros(Float32, 4))

x = rand(Float32, 2, 3)
@@ -120,6 +126,11 @@ end
test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
test_gradients(cell, x, (h, c), loss = loss)

# initial states are zero
h0, c0 = Flux.initialstates(cell)
@test h0 ≈ zeros(Float32, 5)
@test c0 ≈ zeros(Float32, 5)

# no initial state same as zero initial state
hnew1, cnew1 = cell(x[1])
hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
@@ -166,6 +177,12 @@ end
@test size(h) == (4, 3)
@test c isa Array{Float32, 2}
@test size(c) == (4, 3)

# initial states are zero
h0, c0 = Flux.initialstates(lstm)
@test h0 ≈ zeros(Float32, 4)
@test c0 ≈ zeros(Float32, 4)

# no initial state same as zero initial state
h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
@test h ≈ h1
@@ -192,6 +209,9 @@ end
h = randn(Float32, 5)
test_gradients(r, x, h; loss)

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)

# no initial state same as zero initial state
@test r(x[1]) ≈ r(x[1], zeros(Float32, 5))

Expand Down Expand Up @@ -227,8 +247,12 @@ end
@test size(y) == (4, 3, 1)
test_gradients(model, x)

# no initial state same as zero initial state

gru = model.gru
# initial states are zero
@test Flux.initialstates(gru) ≈ zeros(Float32, 4)

# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))

# No Bias
Expand All @@ -246,6 +270,9 @@ end
h = randn(Float32, 5)
test_gradients(r, x, h)

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)

# no initial state same as zero initial state
@test r(x) ≈ r(x, zeros(Float32, 5))

Expand Down Expand Up @@ -277,7 +304,10 @@ end
@test size(y) == (4, 3, 1)
test_gradients(model, x)

# no initial state same as zero initial state
gru = model.gru
# initial states are zero
@test Flux.initialstates(gru) ≈ zeros(Float32, 4)

# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))
end
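The added tests all assert the same contract. A condensed, runnable version of that contract (assuming Flux at this commit, plus the standard-library `Test`):

```julia
using Flux, Test

rnn = RNN(2 => 4)
x = rand(Float32, 2, 3)

# The default state is all zeros...
@test Flux.initialstates(rnn) ≈ zeros(Float32, 4)
# ...and calling without a state matches passing the zero state explicitly.
@test rnn(x) ≈ rnn(x, zeros(Float32, 4))

# LSTMs return a (hidden, cell) tuple of zero states.
h0, c0 = Flux.initialstates(LSTM(2 => 4))
@test h0 == c0 == zeros(Float32, 4)
```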
