invert h and x arguments in RNNs #2521

Closed · wants to merge 5 commits · Changes from 2 commits
2 changes: 1 addition & 1 deletion NEWS.md
@@ -4,7 +4,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl

## v0.15.0
* Recurrent layers have undergone a complete redesign in [PR 2500](https://github.com/FluxML/Flux.jl/pull/2500).
-* `RNNCell`, `LSTMCell`, and `GRUCell` are now exported and provide functionality for single time-step processing: `rnncell(x_t, h_t) -> h_{t+1}`.
+* `RNNCell`, `LSTMCell`, and `GRUCell` are now exported and provide functionality for single time-step processing: `rnncell(h_t, x_t) -> h_{t+1}`.
* `RNN`, `LSTM`, and `GRU` no longer store the hidden state internally; it must be explicitly passed to the layer. Moreover, they now process entire sequences at once, rather than one element at a time: `rnn(x, h) -> h′`.
* The `Recur` wrapper has been deprecated and removed.
* The `reset!` function has also been removed; state management is now entirely up to the user.
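To make the redesigned interface concrete, here is a minimal usage sketch (hypothetical sizes, assuming Flux v0.15 with this PR's `(h, x)` argument order):

```julia
using Flux

# Single time step with a cell: the caller now owns the state.
cell = RNNCell(3 => 5)
x_t = rand(Float32, 3, 7)     # in × batch_size
h = zeros(Float32, 5)
h = cell(h, x_t)              # h_{t+1}, of size 5 × 7

# Whole sequence with the layer: no internal state, no `reset!`.
rnn = RNN(3 => 5)
x = rand(Float32, 3, 10, 7)   # in × len × batch_size
h0 = zeros(Float32, 5)
y = rnn(h0, x)                # hidden states for all steps: 5 × 10 × 7
```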
2 changes: 1 addition & 1 deletion docs/src/guide/models/recurrence.md
@@ -19,7 +19,7 @@ Wxh = randn(Float32, output_size, input_size)
Whh = randn(Float32, output_size, output_size)
b = zeros(Float32, output_size)

-function rnn_cell(x, h)
+function rnn_cell(h, x)
h = tanh.(Wxh * x .+ Whh * h .+ b)
return h
end
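Driving the hand-written cell over a toy sequence then looks as follows (a sketch: `input_size` and `output_size` come from the surrounding guide code, and the sequence length is arbitrary):

```julia
xs = [rand(Float32, input_size) for _ in 1:4]  # 4 time steps
h = zeros(Float32, output_size)                # caller-managed state
for x in xs
    h = rnn_cell(h, x)                         # state first, input second
end
```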
68 changes: 34 additions & 34 deletions src/layers/recurrent.jl
@@ -24,13 +24,13 @@ See [`RNN`](@ref) for a layer that processes entire sequences.

# Forward

-rnncell(x, [h])
+rnncell([h,] x)

The arguments of the forward pass are:

-- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros.
+- `x`: The input to the RNN. It should be a vector of size `in` or a matrix of size `in x batch_size`.

# Examples

@@ -48,7 +48,7 @@ h = zeros(Float32, 5)
ŷ = []

for x_t in x
-h = r(x_t, h)
+h = r(h, x_t)
ŷ = [ŷ..., h] # Cannot use `push!(ŷ, h)` here since mutation
# is not automatic differentiation friendly yet.
# Can use `y = vcat(y, [h])` as an alternative.
@@ -74,9 +74,9 @@ function RNNCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
return RNNCell(σ, Wi, Wh, b)
end

-(m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
+(m::RNNCell)(x::AbstractVecOrMat) = m(zeros_like(x, size(m.Wh, 1)), x)

-function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
+function (m::RNNCell)(h::AbstractVecOrMat, x::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi,2))
σ = NNlib.fast_act(m.σ, x)
h = σ.(m.Wi*x .+ m.Wh*h .+ m.bias)
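The one-argument method above simply defaults the state to zeros, so the two calls in this sketch (made-up sizes) should agree:

```julia
cell = RNNCell(2 => 4)
x = rand(Float32, 2, 3)                       # in × batch_size
@assert cell(x) ≈ cell(zeros(Float32, 4), x)  # default h is a zero vector
```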
@@ -113,7 +113,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.

# Forward

-rnn(x, h)
+rnn(h, x)

The arguments of the forward pass are:

@@ -136,7 +136,7 @@ RNN(
RNNCell(4 => 6, tanh), # 66 parameters
) # Total: 3 arrays, 66 parameters, 424 bytes.

-julia> y = rnn(x, h); # [y] = [d_out, len, batch_size]
+julia> y = rnn(h, x); # [y] = [d_out, len, batch_size]
```

Sometimes, the initial hidden state is a learnable parameter.
@@ -150,7 +150,7 @@ end

Flux.@layer :expand Model

-(m::Model)(x) = m.rnn(x, m.h0)
+(m::Model)(x) = m.rnn(m.h0, x)

model = Model(RNN(32 => 64), zeros(Float32, 64))
```
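Since `h0` is a field of a `Flux.@layer` struct, gradients flow into it like any other parameter. A hedged sketch of what that enables:

```julia
x = rand(Float32, 32, 10)                     # in × len
g = gradient(m -> sum(abs2, m(x)), model)[1]
g.h0  # nonzero: the initial state is trained along with the RNN weights
```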
@@ -166,15 +166,15 @@ function RNN((in, out)::Pair, σ = tanh; bias = true, init = glorot_uniform)
return RNN(cell)
end

-(m::RNN)(x) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
+(m::RNN)(x) = m(zeros_like(x, size(m.cell.Wh, 1)), x)

-function (m::RNN)(x, h)
+function (m::RNN)(h, x)
@assert ndims(x) == 2 || ndims(x) == 3
# [x] = [in, L] or [in, L, B]
# [h] = [out] or [out, B]
y = []
for x_t in eachslice(x, dims=2)
-h = m.cell(x_t, h)
+h = m.cell(h, x_t)
# y = [y..., h]
y = vcat(y, [h])
end
@@ -210,7 +210,7 @@ See also [`LSTM`](@ref) for a layer that processes entire sequences.

# Forward

-lstmcell(x, (h, c))
+lstmcell((h, c), x)
lstmcell(x)

The arguments of the forward pass are:
@@ -233,7 +233,7 @@ julia> c = zeros(Float32, 5); # cell state

julia> x = rand(Float32, 3, 4); # in x batch_size

-julia> h′, c′ = l(x, (h, c));
+julia> h′, c′ = l((h, c), x);

julia> size(h′) # out x batch_size
(5, 4)
@@ -258,10 +258,10 @@ end
function (m::LSTMCell)(x::AbstractVecOrMat)
h = zeros_like(x, size(m.Wh, 2))
c = zeros_like(h)
-return m(x, (h, c))
+return m((h, c), x)
end

-function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
+function (m::LSTMCell)((h, c), x::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
b = m.bias
g = m.Wi * x .+ m.Wh * h .+ b
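The elided remainder of this function chunks `g` into the four gates and applies the standard LSTM update. Roughly, and only as a sketch (gate order and the fast activations are assumptions, not the literal Flux source):

```julia
input, forget, cell, output = chunk(g, 4, dims=1)
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
return h′, c′
```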
@@ -304,7 +304,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.

# Forward

-lstm(x, (h, c))
+lstm((h, c), x)
lstm(x)

The arguments of the forward pass are:
@@ -327,7 +327,7 @@ end

Flux.@layer :expand Model

-(m::Model)(x) = m.lstm(x, (m.h0, m.c0))
+(m::Model)(x) = m.lstm((m.h0, m.c0), x)

d_in, d_out, len, batch_size = 2, 3, 4, 5
x = rand(Float32, (d_in, len, batch_size))
@@ -350,15 +350,15 @@ end
function (m::LSTM)(x)
h = zeros_like(x, size(m.cell.Wh, 1))
c = zeros_like(h)
-return m(x, (h, c))
+return m((h, c), x)
end

-function (m::LSTM)(x, (h, c))
+function (m::LSTM)((h, c), x)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
c′ = []
for x_t in eachslice(x, dims=2)
-h, c = m.cell(x_t, (h, c))
+h, c = m.cell((h, c), x_t)
h′ = vcat(h′, [h])
c′ = vcat(c′, [c])
end
@@ -393,7 +393,7 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.

# Forward

-grucell(x, h)
+grucell(h, x)
grucell(x)

The arguments of the forward pass are:
Expand All @@ -413,7 +413,7 @@ julia> h = zeros(Float32, 5); # hidden state

julia> x = rand(Float32, 3, 4); # in x batch_size

-julia> h′ = g(x, h);
+julia> h′ = g(h, x);
```
"""
struct GRUCell{I,H,V}
@@ -431,9 +431,9 @@ function GRUCell((in, out)::Pair; init = glorot_uniform, bias = true)
return GRUCell(Wi, Wh, b)
end

-(m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+(m::GRUCell)(x::AbstractVecOrMat) = m(zeros_like(x, size(m.Wh, 2)), x)

-function (m::GRUCell)(x::AbstractVecOrMat, h)
+function (m::GRUCell)(h, x::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi,2))
gxs = chunk(m.Wi * x, 3, dims=1)
ghs = chunk(m.Wh * h, 3, dims=1)
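The elided tail computes the standard GRU equations, with the reset gate scaling the hidden-to-hidden term inside the candidate state. A sketch, assuming `b` here stands for the bias chunked into three like `gxs` and `ghs`:

```julia
r  = @. sigmoid_fast(gxs[1] + ghs[1] + b[1])   # reset gate
z  = @. sigmoid_fast(gxs[2] + ghs[2] + b[2])   # update gate
h̃  = @. tanh_fast(gxs[3] + r * ghs[3] + b[3])  # candidate state
h′ = @. (1 - z) * h̃ + z * h
return h′
```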
@@ -472,7 +472,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step.

# Forward

-gru(x, h)
+gru(h, x)
gru(x)

The arguments of the forward pass are:
@@ -506,15 +506,15 @@

function (m::GRU)(x)
h = zeros_like(x, size(m.cell.Wh, 2))
-return m(x, h)
+return m(h, x)
end

-function (m::GRU)(x, h)
+function (m::GRU)(h, x)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
# [x] = [in, L] or [in, L, B]
for x_t in eachslice(x, dims=2)
-h = m.cell(x_t, h)
+h = m.cell(h, x_t)
h′ = vcat(h′, [h])
end
return stack(h′, dims=2)
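A sequence-level call then looks like this (sketch, arbitrary sizes):

```julia
gru = GRU(2 => 3)
x = rand(Float32, 2, 5, 4)   # in × len × batch_size
h0 = zeros(Float32, 3)
y = gru(h0, x)               # 3 × 5 × 4: the hidden state at every step
```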
@@ -548,7 +548,7 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.

# Forward

-gruv3cell(x, h)
+gruv3cell(h, x)
gruv3cell(x)

The arguments of the forward pass are:
@@ -575,9 +575,9 @@ function GRUv3Cell((in, out)::Pair; init = glorot_uniform, bias = true)
return GRUv3Cell(Wi, Wh, b, Wh_h̃)
end

-(m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
+(m::GRUv3Cell)(x::AbstractVecOrMat) = m(zeros_like(x, size(m.Wh, 2)), x)

-function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
+function (m::GRUv3Cell)(h, x::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi,2))
gxs = chunk(m.Wi * x, 3, dims=1)
ghs = chunk(m.Wh * h, 3, dims=1)
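GRUv3 differs from the standard cell only in the candidate state, where the reset-scaled hidden state passes through the separate `Wh_h̃` matrix. A sketch under the same hedges as the GRUCell note above (`b` as the chunked bias):

```julia
r  = @. sigmoid_fast(gxs[1] + ghs[1] + b[1])
z  = @. sigmoid_fast(gxs[2] + ghs[2] + b[2])
h̃  = tanh_fast.(gxs[3] .+ m.Wh_h̃ * (r .* h) .+ b[3])
h′ = @. (1 - z) * h̃ + z * h
return h′
```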
@@ -629,14 +629,14 @@ end

function (m::GRUv3)(x)
h = zeros_like(x, size(m.cell.Wh, 2))
-return m(x, h)
+return m(h, x)
end

-function (m::GRUv3)(x, h)
+function (m::GRUv3)(h, x)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
for x_t in eachslice(x, dims=2)
-h = m.cell(x_t, h)
+h = m.cell(h, x_t)
h′ = vcat(h′, [h])
end
return stack(h′, dims=2)
32 changes: 16 additions & 16 deletions test/ext_common/recurrent_gpu_ad.jl
@@ -1,9 +1,9 @@

@testset "RNNCell GPU AD" begin
-function loss(r, x, h)
+function loss(r, h, x)
y = []
for x_t in x
-h = r(x_t, h)
+h = r(h, x_t)
y = vcat(y, [h])
end
# return mean(h)
@@ -18,7 +18,7 @@
# Single Step
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
# Multiple Steps
-@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
+@test test_gradients(r, h, x; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
end

@testset "RNN GPU AD" begin
@@ -29,7 +29,7 @@ end

Flux.@layer :expand ModelRNN

-(m::ModelRNN)(x) = m.rnn(x, m.h0)
+(m::ModelRNN)(x) = m.rnn(m.h0, x)

d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelRNN(RNN(d_in => d_out), zeros(Float32, d_out))
@@ -41,12 +41,12 @@ end

@testset "LSTMCell" begin

-function loss(r, x, hc)
+function loss(r, hc, x)
h, c = hc
h′ = []
c′ = []
for x_t in x
-h, c = r(x_t, (h, c))
+h, c = r((h, c), x_t)
h′ = vcat(h′, [h])
c′ = [c′..., c]
end
@@ -62,9 +62,9 @@ end
c = zeros(Float32, d_out)
# Single Step
@test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
-loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
+loss = (m, (h, c), x) -> mean(m((h, c), x)[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
# Multiple Steps
-@test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
+@test test_gradients(cell, (h, c), x; test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
end

@testset "LSTM" begin
@@ -89,10 +89,10 @@ end
end

@testset "GRUCell" begin
-function loss(r, x, h)
+function loss(r, h, x)
y = []
for x_t in x
-h = r(x_t, h)
+h = r(h, x_t)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
@@ -104,7 +104,7 @@ end
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS
-@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS
+@test test_gradients(r, h, x; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS
end

@testset "GRU GPU AD" begin
@@ -115,7 +115,7 @@ end

Flux.@layer :expand ModelGRU

-(m::ModelGRU)(x) = m.gru(x, m.h0)
+(m::ModelGRU)(x) = m.gru(m.h0, x)

d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
@@ -126,10 +126,10 @@ end
end

@testset "GRUv3Cell GPU AD" begin
-function loss(r, x, h)
+function loss(r, h, x)
y = []
for x_t in x
-h = r(x_t, h)
+h = r(h, x_t)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
@@ -141,7 +141,7 @@ end
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
-@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
+@test test_gradients(r, h, x; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
end

@testset "GRUv3 GPU AD" begin
@@ -152,7 +152,7 @@ end

Flux.@layer :expand ModelGRUv3

-(m::ModelGRUv3)(x) = m.gru(x, m.h0)
+(m::ModelGRUv3)(x) = m.gru(m.h0, x)

d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))