
Commit

gpu
CarloLucibello committed Dec 12, 2024
1 parent 5da71b8 commit ba83bc1
Showing 2 changed files with 21 additions and 20 deletions.
23 changes: 11 additions & 12 deletions docs/src/guide/models/recurrence.md
@@ -21,7 +21,7 @@ b = zeros(Float32, output_size)

function rnn_cell(x, h)
    h = tanh.(Wxh * x .+ Whh * h .+ b)
-    return h
+    return h, h
end

seq_len = 3
@@ -33,14 +33,14 @@ h0 = zeros(Float32, output_size)
y = []
ht = h0
for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]] # concatenate in non-mutating (AD friendly) way
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]] # concatenate in non-mutating (AD friendly) way
end
```

Notice how the above is essentially a `Dense` layer that acts on two inputs, `xt` and `ht`.
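This equivalence is easy to verify in code. The following is a minimal sketch added for this write-up (not part of the diffed guide; sizes are hypothetical): the cell update equals a single `Dense` layer with weight `[Wxh Whh]` applied to the concatenation `[x; h]`.

```julia
using Flux

# Hypothetical sizes, chosen only for illustration.
input_size, output_size = 4, 3
Wxh = randn(Float32, output_size, input_size)
Whh = randn(Float32, output_size, output_size)
b = zeros(Float32, output_size)

# A Dense layer built from the block matrix [Wxh Whh], with the same bias and activation.
dense = Dense(hcat(Wxh, Whh), b, tanh)

x = rand(Float32, input_size)
h = rand(Float32, output_size)
@assert dense(vcat(x, h)) ≈ tanh.(Wxh * x .+ Whh * h .+ b)
```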

-The output at each time step, called the hidden state, is used as the input to the next time step and is also the output of the model.
+The result of the forward pass at each time step is a tuple containing the output `yt` and the updated state `ht`. The updated state is used as input in the next iteration. In the simple case of a vanilla RNN, the output and the state are the same. In more complex cells, such as `LSTMCell`, the state can contain multiple arrays.
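As a quick illustration of this `(output, state)` convention, here is a sketch added for this write-up (not part of the diffed guide; sizes are hypothetical):

```julia
using Flux

input_size, output_size = 4, 3
x  = rand(Float32, input_size)
h0 = zeros(Float32, output_size)
c0 = zeros(Float32, output_size)

rnn  = Flux.RNNCell(input_size => output_size)
lstm = Flux.LSTMCell(input_size => output_size)

yt, ht = rnn(x, h0)                  # vanilla RNN: output and state are the same array
yt2, (ht2, ct2) = lstm(x, (h0, c0))  # LSTM: the state holds two arrays (hidden and cell)
```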

There are various recurrent cells available in Flux, notably `RNNCell`, `LSTMCell` and `GRUCell`, which are documented in the [layer reference](../../reference/models/layers.md). The hand-written example above can be replaced with:

@@ -58,8 +58,8 @@ rnn_cell = Flux.RNNCell(input_size => output_size)
y = []
ht = h0
for xt in x
-    ht = rnn_cell(xt, ht)
-    y = [y; [ht]]
+    yt, ht = rnn_cell(xt, ht)
+    y = [y; [yt]]
end
```
The entire output `y` or just the last output `y[end]` can be used for further processing, such as classification or regression.
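For instance, a hedged sketch of a classification head on the last output, using `y` and `output_size` from the snippet above (`nclasses` is a hypothetical number of classes):

```julia
using Flux

nclasses = 4                        # hypothetical number of target classes
head = Dense(output_size => nclasses)

logits = head(y[end])               # classify from the final time step only
probs = softmax(logits)             # class probabilities
```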
@@ -78,7 +78,7 @@ struct RecurrentCellModel{H,C,D}
end

# we choose to not train the initial hidden state
-Flux.@layer RecurrentCellModel trainable=(cell,dense)
+Flux.@layer RecurrentCellModel trainable=(cell, dense)

function RecurrentCellModel(input_size::Int, hidden_size::Int)
    return RecurrentCellModel(
@@ -91,8 +91,8 @@ function (m::RecurrentCellModel)(x)
    z = []
    ht = m.h0
    for xt in x
-        ht = m.cell(xt, ht)
-        z = [z; [ht]]
+        yt, ht = m.cell(xt, ht)
+        z = [z; [yt]]
    end
    z = stack(z, dims=2) # [hidden_size, seq_len, batch_size] or [hidden_size, seq_len]
    ŷ = m.dense(z)       # [1, seq_len, batch_size] or [1, seq_len]
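A rough shape check for this forward pass, as a sketch assuming the full `RecurrentCellModel` definition from the surrounding guide (sizes are hypothetical):

```julia
input_size, hidden_size, batch_size, seq_len = 2, 5, 3, 4
m = RecurrentCellModel(input_size, hidden_size)
x = [rand(Float32, input_size, batch_size) for _ in 1:seq_len]

ŷ = m(x)
@assert size(ŷ) == (1, seq_len, batch_size)   # the Dense head maps hidden_size => 1
```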
@@ -109,7 +109,6 @@ using Optimisers: AdamW

function loss(model, x, y)
    ŷ = model(x)
-    y = stack(y, dims=2)
    return Flux.mse(ŷ, y)
end

@@ -123,7 +122,7 @@ model = RecurrentCellModel(input_size, 5)
opt_state = Flux.setup(AdamW(1e-3), model)

# compute the gradient and update the model
-g = gradient(m -> loss(m, x, y),model)[1]
+g = gradient(m -> loss(m, x, y), model)[1]
Flux.update!(opt_state, model, g)
```
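Since the test changes in this commit target GPU execution, here is a hedged sketch (not part of the diffed docs) of running the same model on a GPU, assuming a backend package such as CUDA.jl is loaded; without one, `gpu` is a no-op:

```julia
using Flux

gpu_model = gpu(model)           # move the parameters and the initial state
gpu_x = [gpu(xt) for xt in x]    # move each time step's input
ŷ = gpu_model(gpu_x)
```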

18 changes: 10 additions & 8 deletions test/ext_common/recurrent_gpu_ad.jl
@@ -1,11 +1,9 @@
-out_from_state(state::Tuple) = state[1]
-out_from_state(state) = state
+cell_loss(cell, x, state) = mean(cell(x, state)[1])

function recurrent_cell_loss(cell, seq, state)
    out = []
    for xt in seq
-        state = cell(xt, state)
-        yt = out_from_state(state)
+        yt, state = cell(xt, state)
        out = vcat(out, [yt])
    end
    return mean(stack(out, dims = 2))
@@ -18,7 +16,8 @@ end
    h = zeros(Float32, d_out)
    # Single Step
    @test test_gradients(r, x[1], h; test_gpu=true,
-        compare_finite_diff=false) broken = :rnncell_single BROKEN_TESTS
+        compare_finite_diff=false,
+        loss=cell_loss) broken = :rnncell_single BROKEN_TESTS
    # Multiple Steps
    @test test_gradients(r, x, h; test_gpu=true,
        compare_finite_diff=false,
@@ -51,7 +50,7 @@ end
    c = zeros(Float32, d_out)
    # Single Step
    @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
-        loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single BROKEN_TESTS
+        loss = cell_loss) broken = :lstmcell_single BROKEN_TESTS
    # Multiple Steps
    @test test_gradients(cell, x, (h, c); test_gpu=true,
        compare_finite_diff = false,
@@ -84,7 +83,9 @@ end
    r = GRUCell(d_in => d_out)
    x = [randn(Float32, d_in, batch_size) for _ in 1:len]
    h = zeros(Float32, d_out)
-    @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single BROKEN_TESTS
+    @test test_gradients(r, x[1], h; test_gpu=true,
+        compare_finite_diff=false,
+        loss = cell_loss) broken = :grucell_single BROKEN_TESTS
    @test test_gradients(r, x, h; test_gpu=true,
        compare_finite_diff = false,
        loss = recurrent_cell_loss) broken = :grucell_multiple BROKEN_TESTS
@@ -116,7 +117,8 @@ end
    x = [randn(Float32, d_in, batch_size) for _ in 1:len]
    h = zeros(Float32, d_out)
    @test test_gradients(r, x[1], h; test_gpu=true,
-        compare_finite_diff=false) broken = :gruv3cell_single BROKEN_TESTS
+        compare_finite_diff=false,
+        loss=cell_loss) broken = :gruv3cell_single BROKEN_TESTS
    @test test_gradients(r, x, h; test_gpu=true,
        compare_finite_diff=false,
        loss = recurrent_cell_loss) broken = :gruv3cell_multiple BROKEN_TESTS
