From ba83bc1e20cb5b3f703845ceebceb96856a384ef Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 12 Dec 2024 18:06:17 +0100
Subject: [PATCH] gpu

---
 docs/src/guide/models/recurrence.md | 23 +++++++++++------------
 test/ext_common/recurrent_gpu_ad.jl | 18 ++++++++++--------
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/docs/src/guide/models/recurrence.md b/docs/src/guide/models/recurrence.md
index 5b2e70f095..96fd507e8b 100644
--- a/docs/src/guide/models/recurrence.md
+++ b/docs/src/guide/models/recurrence.md
@@ -21,7 +21,7 @@ b = zeros(Float32, output_size)
 
 function rnn_cell(x, h)
     h = tanh.(Wxh * x .+ Whh * h .+ b)
-    return h
+    return h, h
 end
 
 seq_len = 3
@@ -33,14 +33,14 @@ h0 = zeros(Float32, output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]]  # concatenate in non-mutating (AD friendly) way
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]]  # concatenate in non-mutating (AD friendly) way
 end
 ```
 
 Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.
-
-The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
+The result of the forward pass at each time step is a tuple containing the output `yt` and the updated state `ht`. The updated state is used as an input in the next iteration. In the simple case of a vanilla RNN, the
+output and the state are the same. In more complex cells, such as `LSTMCell`, the state can contain multiple arrays.
 
 There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with:
 
@@ -58,8 +58,8 @@ rnn_cell = Flux.RNNCell(input_size => output_size)
 y = []
 ht = h0
 for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]]
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]]
 end
 ```
 The entire output `y` or just the last output `y[end]` can be used for further processing, such as classification or regression.
@@ -78,7 +78,7 @@ struct RecurrentCellModel{H,C,D}
 end
 
 # we choose to not train the initial hidden state
-Flux.@layer RecurrentCellModel trainable=(cell,dense)
+Flux.@layer RecurrentCellModel trainable=(cell, dense)
 
 function RecurrentCellModel(input_size::Int, hidden_size::Int)
     return RecurrentCellModel(
@@ -91,8 +91,8 @@ function (m::RecurrentCellModel)(x)
     z = []
     ht = m.h0
     for xt in x
-        ht = m.cell(xt, ht)
-        z = [z; [ht]]
+        yt, ht = m.cell(xt, ht)
+        z = [z; [yt]]
     end
     z = stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
     ŷ = m.dense(z) # [1, seq_len, batch_size] or [1, seq_len]
@@ -109,7 +109,6 @@ using Optimisers: AdamW
 
 function loss(model, x, y)
     ŷ = model(x)
-    y = stack(y, dims=2)
     return Flux.mse(ŷ, y)
 end
 
@@ -123,7 +122,7 @@ model = RecurrentCellModel(input_size, 5)
 opt_state = Flux.setup(AdamW(1e-3), model)
 
 # compute the gradient and update the model
-g = gradient(m -> loss(m, x, y),model)[1]
+g = gradient(m -> loss(m, x, y), model)[1]
 Flux.update!(opt_state, model, g)
 ```
 
diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl
index 3046f1f1fa..59f40242d9 100644
--- a/test/ext_common/recurrent_gpu_ad.jl
+++ b/test/ext_common/recurrent_gpu_ad.jl
@@ -1,11 +1,9 @@
-out_from_state(state::Tuple) = state[1]
-out_from_state(state) = state
+cell_loss(cell, x, state) = mean(cell(x, state)[1])
 
 function recurrent_cell_loss(cell, seq, state)
     out = []
     for xt in seq
-        state = cell(xt, state)
-        yt = out_from_state(state)
+        yt, state = cell(xt, state)
         out = vcat(out, [yt])
     end
     return mean(stack(out, dims = 2))
@@ -18,7 +16,8 @@ end
     h = zeros(Float32, d_out)
     # Single Step
     @test test_gradients(r, x[1], h; test_gpu=true,
-                        compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
+                        compare_finite_diff=false,
+                        loss=cell_loss) broken = :rnncell_single ∈ BROKEN_TESTS
     # Multiple Steps
     @test test_gradients(r, x, h; test_gpu=true,
         compare_finite_diff=false,
@@ -51,7 +50,7 @@ end
     c = zeros(Float32, d_out)
     # Single Step
     @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
-        loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
+        loss = cell_loss) broken = :lstmcell_single ∈ BROKEN_TESTS
     # Multiple Steps
     @test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff = false,
@@ -84,7 +83,9 @@ end
     r = GRUCell(d_in => d_out)
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
-    @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS
+    @test test_gradients(r, x[1], h; test_gpu=true,
+                        compare_finite_diff=false,
+                        loss = cell_loss) broken = :grucell_single ∈ BROKEN_TESTS
     @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff = false,
         loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
@@ -116,7 +117,8 @@ end
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     @test test_gradients(r, x[1], h; test_gpu=true,
-                        compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
+                        compare_finite_diff=false,
+                        loss=cell_loss) broken = :gruv3cell_single ∈ BROKEN_TESTS
     @test test_gradients(r, x, h; test_gpu=true,
         compare_finite_diff=false,
         loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
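Below is a minimal usage sketch of the `(output, state)` call convention that this patch documents and tests. It assumes the Flux version targeted by the patch, where each cell call returns an output/state tuple; the sizes and the `unroll` helper are hypothetical and chosen only for illustration.

```julia
using Flux

# Hypothetical sizes, for illustration only.
input_size, hidden_size, seq_len = 2, 5, 3
x = [rand(Float32, input_size) for _ in 1:seq_len]

# Generic unroll: every cell call returns (output, state), so the same loop
# works for RNNCell, GRUCell and LSTMCell (whose state is an (h, c) tuple).
function unroll(cell, x, state)
    y = []
    for xt in x
        yt, state = cell(xt, state)   # unpack output and updated state
        y = [y; [yt]]                 # non-mutating concatenation (AD friendly)
    end
    return stack(y, dims=2), state
end

rnn = Flux.RNNCell(input_size => hidden_size)
y, h = unroll(rnn, x, zeros(Float32, hidden_size))   # y is hidden_size × seq_len

lstm = Flux.LSTMCell(input_size => hidden_size)
y2, (h2, c2) = unroll(lstm, x, (zeros(Float32, hidden_size), zeros(Float32, hidden_size)))
```

The same unpacking pattern underlies the loop added to the docs above and the `cell_loss` / `recurrent_cell_loss` helpers in the GPU AD tests.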