Adding initialstates function to RNNs #2541

Merged 5 commits on Dec 10, 2024
1 change: 1 addition & 0 deletions docs/src/reference/models/layers.md
@@ -117,6 +117,7 @@ GRUCell
GRU
GRUv3Cell
GRUv3
Flux.initialstates
```

## Normalisation & Regularisation
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -54,7 +54,7 @@ export Chain, Dense, Embedding, EmbeddingBag,
# layers
Bilinear, Scale,
# utils
-    outputsize, state, create_bias, @layer,
+    outputsize, state, create_bias, @layer, initialstates,
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
# from Train
121 changes: 96 additions & 25 deletions src/layers/recurrent.jl
@@ -32,7 +32,7 @@ The arguments of the forward pass are:

- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

# Examples

@@ -69,6 +69,29 @@ end

@layer RNNCell

"""
    initialstates(rnn) -> AbstractVector

Return the initial hidden state for the given recurrent cell or recurrent layer.
For LSTM cells and layers, which carry both a hidden state and a cell state,
a tuple of two such vectors is returned instead.

# Example
```julia
using Flux

# Create an RNNCell from input dimension 10 to output dimension 20
rnn = RNNCell(10 => 20)

# Get the initial hidden state
h0 = initialstates(rnn)

# Get some input data
x = rand(Float32, 10)

# Run forward
res = rnn(x, h0)
```
"""
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
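
For illustration, a minimal sketch (not part of this diff) of how a user-defined cell could overload `initialstates`, e.g. to learn the initial state instead of using zeros; `MyCell` and its fields are hypothetical names:

```julia
using Flux

# Hypothetical wrapper cell whose initial hidden state is a trainable parameter.
struct MyCell{C,H}
    cell::C   # an inner recurrent cell, e.g. an RNNCell
    h0::H     # trainable initial hidden state
end

Flux.@layer MyCell  # makes the fields (including h0) trainable

# Overload initialstates so the cell starts from h0 instead of zeros.
Flux.initialstates(m::MyCell) = m.h0

(m::MyCell)(x) = m.cell(x, Flux.initialstates(m))
(m::MyCell)(x, h) = m.cell(x, h)

m = MyCell(RNNCell(3 => 5), zeros(Float32, 5))
y = m(rand(Float32, 3))   # forward pass starting from m.h0
```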

function RNNCell(
(in, out)::Pair,
σ = tanh;
@@ -82,7 +105,7 @@ function RNNCell(
return RNNCell(σ, Wi, Wh, b)
end

-(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
+function (rnn::RNNCell)(x::AbstractVecOrMat)
+    state = initialstates(rnn)
+    return rnn(x, state)
+end
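
With this dispatch, calling the cell without a state is the same as passing `initialstates` explicitly. A quick sketch of that equivalence, with illustrative sizes:

```julia
using Flux

rnn = RNNCell(2 => 3)
x = rand(Float32, 2)

# The stateless call fills in initialstates(rnn) (a zero vector) for us.
rnn(x) ≈ rnn(x, Flux.initialstates(rnn))  # true
```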

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -130,7 +156,7 @@ The arguments of the forward pass are:
- `x`: The input to the RNN. It should be a matrix size `in x len` or an array of size `in x len x batch_size`.
- `h`: The initial hidden state of the RNN.
If given, it is a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
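
As a shape sketch of this forward pass (illustrative dimensions):

```julia
using Flux

d_in, d_out, len, batch_size = 2, 3, 4, 5
rnn = RNN(d_in => d_out)
x = rand(Float32, d_in, len, batch_size)

h = rnn(x)  # the initial state defaults to initialstates(rnn)
size(h) == (d_out, len, batch_size)  # true
```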

@@ -173,12 +199,17 @@ end

@layer RNN

initialstates(rnn::RNN) = initialstates(rnn.cell)

function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
cell = RNNCell(in => out, σ; cell_kwargs...)
return RNN(cell)
end

-(m::RNN)(x::AbstractArray) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
+function (rnn::RNN)(x::AbstractArray)
+    state = initialstates(rnn)
+    return rnn(x, state)
+end

function (m::RNN)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
@@ -231,7 +262,7 @@ The arguments of the forward pass are:
- `x`: The input to the LSTM. It should be a matrix of size `in` or an array of size `in x batch_size`.
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
They should be vectors of size `out` or matrices of size `out x batch_size`.
-  If not provided, they are assumed to be vectors of zeros.
+  If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).

Returns a tuple `(h′, c′)` containing the new hidden state and cell state in tensors of size `out` or `out x batch_size`.
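
A brief single-step sketch (illustrative sizes) showing the tuple state threading through the cell:

```julia
using Flux

lstm = LSTMCell(2 => 3)
x = rand(Float32, 2, 5)    # batch of 5

h, c = lstm(x)             # state defaults to initialstates(lstm)
h2, c2 = lstm(x, (h, c))   # or pass the state explicitly
size(h) == (3, 5)          # true
```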

@@ -261,6 +292,10 @@ end

@layer LSTMCell

function initialstates(lstm::LSTMCell)
    return zeros_like(lstm.Wh, size(lstm.Wh, 2)), zeros_like(lstm.Wh, size(lstm.Wh, 2))
end

function LSTMCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
@@ -274,10 +309,9 @@ function LSTMCell(
return cell
end

-function (m::LSTMCell)(x::AbstractVecOrMat)
-    h = zeros_like(x, size(m.Wh, 2))
-    c = zeros_like(h)
-    return m(x, (h, c))
+function (lstm::LSTMCell)(x::AbstractVecOrMat)
+    state, cstate = initialstates(lstm)
+    return lstm(x, (state, cstate))
end

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
@@ -332,7 +366,7 @@ The arguments of the forward pass are:
- `x`: The input to the LSTM. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
- `(h, c)`: A tuple containing the hidden and cell states of the LSTM.
They should be vectors of size `out` or matrices of size `out x batch_size`.
-  If not provided, they are assumed to be vectors of zeros.
+  If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).

Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
in tensors of size `out x len` or `out x len x batch_size`.
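
A shape sketch for the sequence case (illustrative dimensions):

```julia
using Flux

d_in, d_out, len, batch_size = 2, 3, 4, 5
lstm = LSTM(d_in => d_out)
x = rand(Float32, d_in, len, batch_size)

h, c = lstm(x)  # state defaults to initialstates(lstm)
size(h) == (d_out, len, batch_size)  # true, and likewise for c
```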
@@ -363,15 +397,16 @@ end

@layer LSTM

initialstates(lstm::LSTM) = initialstates(lstm.cell)

function LSTM((in, out)::Pair; cell_kwargs...)
cell = LSTMCell(in => out; cell_kwargs...)
return LSTM(cell)
end

-function (m::LSTM)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    c = zeros_like(h)
-    return m(x, (h, c))
+function (lstm::LSTM)(x::AbstractArray)
+    state, cstate = initialstates(lstm)
+    return lstm(x, (state, cstate))
end

function (m::LSTM)(x::AbstractArray, (h, c))
@@ -422,7 +457,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.
The arguments of the forward pass are:
- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
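
A single-step sketch (illustrative sizes):

```julia
using Flux

gru = GRUCell(2 => 3)
x = rand(Float32, 2, 5)  # batch of 5

h = gru(x)               # starts from initialstates(gru), i.e. zeros
size(h) == (3, 5)        # true
```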

@@ -447,6 +482,8 @@ end

@layer GRUCell

initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
@@ -459,7 +496,10 @@ function GRUCell(
return GRUCell(Wi, Wh, b)
end

-(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+function (gru::GRUCell)(x::AbstractVecOrMat)
+    state = initialstates(gru)
+    return gru(x, state)
+end

function (m::GRUCell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -514,7 +554,7 @@ The arguments of the forward pass are:

- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.
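
A shape sketch (illustrative dimensions):

```julia
using Flux

d_in, d_out, len, batch_size = 2, 3, 4, 5
gru = GRU(d_in => d_out)
x = rand(Float32, d_in, len, batch_size)

h = gru(x)  # state defaults to initialstates(gru)
size(h) == (d_out, len, batch_size)  # true
```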

@@ -534,14 +574,16 @@ end

@layer GRU

initialstates(gru::GRU) = initialstates(gru.cell)

function GRU((in, out)::Pair; cell_kwargs...)
cell = GRUCell(in => out; cell_kwargs...)
return GRU(cell)
end

-function (m::GRU)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    return m(x, h)
+function (gru::GRU)(x::AbstractArray)
+    state = initialstates(gru)
+    return gru(x, state)
end

function (m::GRU)(x::AbstractArray, h)
@@ -590,7 +632,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
The arguments of the forward pass are:
- `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
-  If not provided, it is assumed to be a vector of zeros.
+  If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
"""
@@ -603,6 +645,8 @@ end

@layer GRUv3Cell

initialstates(gru::GRUv3Cell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUv3Cell(
(in, out)::Pair;
init_kernel = glorot_uniform,
@@ -616,7 +660,10 @@ function GRUv3Cell(
return GRUv3Cell(Wi, Wh, b, Wh_h̃)
end

-(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+function (gru::GRUv3Cell)(x::AbstractVecOrMat)
+    state = initialstates(gru)
+    return gru(x, state)
+end

function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -667,21 +714,45 @@ but only a less popular variant.
- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
- `bias`: Whether to include a bias term initialized to zero. Default is `true`.

# Forward

gruv3(x, [h])

The arguments of the forward pass are:

- `x`: The input to the GRU. It should be a matrix of size `in x len` or an array of size `in x len x batch_size`.
- `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).

Returns all new hidden states `h_t` as an array of size `out x len x batch_size`.

# Examples

```julia
d_in, d_out, len, batch_size = 2, 3, 4, 5
gruv3 = GRUv3(d_in => d_out)
x = rand(Float32, (d_in, len, batch_size))
h0 = zeros(Float32, d_out)
h = gruv3(x, h0) # out x len x batch_size
```
"""
struct GRUv3{M}
cell::M
end

@layer GRUv3

initialstates(gru::GRUv3) = initialstates(gru.cell)

function GRUv3((in, out)::Pair; cell_kwargs...)
cell = GRUv3Cell(in => out; cell_kwargs...)
return GRUv3(cell)
end

-function (m::GRUv3)(x::AbstractArray)
-    h = zeros_like(x, size(m.cell.Wh, 2))
-    return m(x, h)
+function (gru::GRUv3)(x::AbstractArray)
+    state = initialstates(gru)
+    return gru(x, state)
end

function (m::GRUv3)(x::AbstractArray, h)
36 changes: 33 additions & 3 deletions test/layers/recurrent.jl
@@ -43,6 +43,9 @@
test_gradients(r, x, h, loss=loss3) # splat
test_gradients(r, x, h, loss=loss4) # vcat and stack

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)

# no initial state same as zero initial state
@test r(x[1]) ≈ r(x[1], zeros(Float32, 5))

@@ -80,8 +83,11 @@ end
@test size(y) == (4, 3, 1)
test_gradients(model, x)

+rnn = model.rnn
+# initial states are zero
+@test Flux.initialstates(rnn) ≈ zeros(Float32, 4)
+
# no initial state same as zero initial state
-rnn = model.rnn
@test rnn(x) ≈ rnn(x, zeros(Float32, 4))

x = rand(Float32, 2, 3)
@@ -120,6 +126,11 @@ end
test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
test_gradients(cell, x, (h, c), loss = loss)

# initial states are zero
h0, c0 = Flux.initialstates(cell)
@test h0 ≈ zeros(Float32, 5)
@test c0 ≈ zeros(Float32, 5)

# no initial state same as zero initial state
hnew1, cnew1 = cell(x[1])
hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
@@ -166,6 +177,12 @@ end
@test size(h) == (4, 3)
@test c isa Array{Float32, 2}
@test size(c) == (4, 3)

# initial states are zero
h0, c0 = Flux.initialstates(lstm)
@test h0 ≈ zeros(Float32, 4)
@test c0 ≈ zeros(Float32, 4)

# no initial state same as zero initial state
h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
@test h ≈ h1
@@ -192,6 +209,9 @@ end
h = randn(Float32, 5)
test_gradients(r, x, h; loss)

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)

# no initial state same as zero initial state
@test r(x[1]) ≈ r(x[1], zeros(Float32, 5))

@@ -227,8 +247,12 @@ end
@test size(y) == (4, 3, 1)
test_gradients(model, x)

-# no initial state same as zero initial state
-
gru = model.gru
+# initial states are zero
+@test Flux.initialstates(gru) ≈ zeros(Float32, 4)
+
+# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))

# No Bias
@@ -246,6 +270,9 @@ end
h = randn(Float32, 5)
test_gradients(r, x, h)

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)

# no initial state same as zero initial state
@test r(x) ≈ r(x, zeros(Float32, 5))

@@ -277,7 +304,10 @@ end
@test size(y) == (4, 3, 1)
test_gradients(model, x)

-# no initial state same as zero initial state
gru = model.gru
+# initial states are zero
+@test Flux.initialstates(gru) ≈ zeros(Float32, 4)
+
+# no initial state same as zero initial state
@test gru(x) ≈ gru(x, zeros(Float32, 4))
end