Commit

hotfix LSTM output (#2547)
CarloLucibello authored Dec 11, 2024
1 parent 428be48 commit 130af41
Showing 4 changed files with 87 additions and 134 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -48,8 +48,8 @@ CUDA = "5"
ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.13"
Functors = "0.5"
EnzymeCore = "0.7.7, 0.8.4"
Functors = "0.5"
MLDataDevices = "1.4.2"
MLUtils = "0.4"
MPI = "0.20.19"
99 changes: 39 additions & 60 deletions src/layers/recurrent.jl
@@ -1,6 +1,20 @@
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

function scan(cell, x, state0)
state = state0
y = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
out = out_from_state(state)
y = vcat(y, [out])
end
return stack(y, dims = 2)
end
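
For context, here is a minimal standalone sketch of the pattern the new `scan` helper implements. The helper itself is internal to the layer code above, so this re-implements the same loop under the hypothetical name `scan_demo`; the use of `RNNCell` and the sizes are illustrative only, not part of the diff.

```julia
using Flux

# Re-implementation of the scan loop above, for illustration only.
demo_out(state) = state                 # plain cells return the output itself
demo_out(state::Tuple) = state[1]       # tuple states (e.g. LSTM's (h, c)) expose h

function scan_demo(cell, x, state0)
    state = state0
    y = []
    for x_t in eachslice(x, dims = 2)   # step through the time dimension
        state = cell(x_t, state)        # one step of the recurrence
        y = vcat(y, [demo_out(state)])  # collect only the output part of the state
    end
    return stack(y, dims = 2)           # out x len (x batch)
end

cell = Flux.RNNCell(2 => 3)
x = rand(Float32, 2, 4, 5)              # in x len x batch
h0 = zeros(Float32, 3)
size(scan_demo(cell, x, h0))            # (3, 4, 5)
```
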


# Vanilla RNN

# Vanilla RNN
@doc raw"""
RNNCell(in => out, σ = tanh; init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform, bias = true)
@@ -215,13 +229,7 @@ function (m::RNN)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
# [x] = [in, L] or [in, L, B]
# [h] = [out] or [out, B]
y = []
for x_t in eachslice(x, dims = 2)
h = m.cell(x_t, h)
# y = [y..., h]
y = vcat(y, [h])
end
return stack(y, dims = 2)
return scan(m.cell, x, h)
end


@@ -297,22 +305,20 @@ function initialstates(lstm:: LSTMCell)
end

function LSTMCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)

Wi = init_kernel(out * 4, in)
Wh = init_recurrent_kernel(out * 4, out)
b = create_bias(Wi, bias, out * 4)
cell = LSTMCell(Wi, Wh, b)
return cell
end

function (lstm::LSTMCell)(x::AbstractVecOrMat)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end
(lstm::LSTMCell)(x::AbstractVecOrMat) = lstm(x, initialstates(lstm))

function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
_size_check(m, x, 1 => size(m.Wi, 2))
@@ -368,15 +374,14 @@ The arguments of the forward pass are:
They should be vectors of size `out` or matrices of size `out x batch_size`.
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
in tensors of size `out x len` or `out x len x batch_size`.
Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
# Examples
```julia
struct Model
lstm::LSTM
h0::AbstractVector
h0::AbstractVector # trainable initial hidden state
c0::AbstractVector
end
@@ -387,7 +392,7 @@ Flux.@layer Model
d_in, d_out, len, batch_size = 2, 3, 4, 5
x = rand(Float32, (d_in, len, batch_size))
model = Model(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
h, c = model(x)
h = model(x)
size(h) # out x len x batch_size
```
"""
@@ -404,21 +409,11 @@ function LSTM((in, out)::Pair; cell_kwargs...)
return LSTM(cell)
end

function (lstm::LSTM)(x::AbstractArray)
state, cstate = initialstates(lstm)
return lstm(x, (state, cstate))
end
(lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))

function (m::LSTM)(x::AbstractArray, (h, c))
function (m::LSTM)(x::AbstractArray, state0)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
c′ = []
for x_t in eachslice(x, dims = 2)
h, c = m.cell(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = vcat(c′, [c])
end
return stack(h′, dims = 2), stack(c′, dims = 2)
return scan(m.cell, x, state0)
end
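
To make the fix concrete, here is a hedged usage sketch based on the updated docstring above (the dimensions follow the docstring example; the snippet is illustrative and not part of the diff): after this change the `LSTM` layer returns a single array of hidden states, while a single `LSTMCell` step still returns the `(h, c)` tuple.

```julia
using Flux

d_in, d_out, len, batch_size = 2, 3, 4, 5
lstm = LSTM(d_in => d_out)
x = rand(Float32, d_in, len, batch_size)

h = lstm(x)                    # hidden states only, no (h, c) tuple
size(h)                        # (d_out, len, batch_size)

# The cell-level API is unchanged: one step still yields (h, c).
cell = LSTMCell(d_in => d_out)
h_t, c_t = cell(x[:, 1, :])    # each is d_out x batch_size
```
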

# GRU
@@ -485,11 +480,12 @@ end
initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))

function GRUCell(
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
(in, out)::Pair;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)

Wi = init_kernel(out * 3, in)
Wh = init_recurrent_kernel(out * 3, out)
b = create_bias(Wi, bias, size(Wi, 1))
@@ -581,20 +577,11 @@ function GRU((in, out)::Pair; cell_kwargs...)
return GRU(cell)
end

function (gru::GRU)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end
(gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))

function (m::GRU)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
# [x] = [in, L] or [in, L, B]
for x_t in eachslice(x, dims = 2)
h = m.cell(x_t, h)
h′ = vcat(h′, [h])
end
return stack(h′, dims = 2)
return scan(m.cell, x, h)
end

# GRU v3
@@ -750,17 +737,9 @@ function GRUv3((in, out)::Pair; cell_kwargs...)
return GRUv3(cell)
end

function (gru::GRUv3)(x::AbstractArray)
state = initialstates(gru)
return gru(x, state)
end
(gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))

function (m::GRUv3)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
h′ = []
for x_t in eachslice(x, dims = 2)
h = m.cell(x_t, h)
h′ = vcat(h′, [h])
end
return stack(h′, dims = 2)
return scan(m.cell, x, h)
end
100 changes: 40 additions & 60 deletions test/ext_common/recurrent_gpu_ad.jl
@@ -1,24 +1,28 @@

@testset "RNNCell GPU AD" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
# return mean(h)
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
out_from_state(state::Tuple) = state[1]
out_from_state(state) = state

function recurrent_cell_loss(cell, seq, state)
out = []
for xt in seq
state = cell(xt, state)
yt = out_from_state(state)
out = vcat(out, [yt])
end
return mean(stack(out, dims = 2))
end
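
A hedged sketch of how this shared test helper is exercised (CPU only, assuming the `out_from_state` and `recurrent_cell_loss` definitions above are in scope; the dimensions mirror the testsets below):

```julia
using Flux, Statistics

d_in, d_out, len, batch_size = 2, 3, 4, 5
seq = [randn(Float32, d_in, batch_size) for _ in 1:len]

# Cells whose state is the output itself (RNN/GRU flavours):
rcell = RNNCell(d_in => d_out)
recurrent_cell_loss(rcell, seq, zeros(Float32, d_out))        # scalar loss

# Cells whose state is a tuple: out_from_state picks h out of (h, c).
lcell = LSTMCell(d_in => d_out)
recurrent_cell_loss(lcell, seq, (zeros(Float32, d_out), zeros(Float32, d_out)))
```
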

@testset "RNNCell GPU AD" begin
d_in, d_out, len, batch_size = 2, 3, 4, 5
r = RNNCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
# Single Step
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
@test test_gradients(r, x[1], h; test_gpu=true,
compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
# Multiple Steps
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff=false,
loss=recurrent_cell_loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
end

@testset "RNN GPU AD" begin
@@ -40,21 +44,6 @@ end
end

@testset "LSTMCell" begin

function loss(r, x, hc)
h, c = hc
h′ = []
c′ = []
for x_t in x
h, c = r(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = [c′..., c]
end
hnew = stack(h′, dims=2)
cnew = stack(c′, dims=2)
return mean(hnew) + mean(cnew)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
cell = LSTMCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
@@ -64,7 +53,9 @@ end
@test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
# Multiple Steps
@test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
@test test_gradients(cell, x, (h, c); test_gpu=true,
compare_finite_diff = false,
loss = recurrent_cell_loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
end

@testset "LSTM" begin
@@ -81,30 +72,22 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false,
loss = (m, x) -> mean(m(x)[1])) broken = :lstm_nobatch ∈ BROKEN_TESTS
@test test_gradients(model, x_nobatch; test_gpu=true,
compare_finite_diff=false) broken = :lstm_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false,
loss = (m, x) -> mean(m(x)[1])) broken = :lstm ∈ BROKEN_TESTS
@test test_gradients(model, x; test_gpu=true,
compare_finite_diff=false) broken = :lstm ∈ BROKEN_TESTS
end

@testset "GRUCell" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
r = GRUCell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff = false,
loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
end

@testset "GRU GPU AD" begin
@@ -120,28 +103,23 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
@test test_gradients(model, x_nobatch; test_gpu=true,
compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
@test test_gradients(model, x; test_gpu=true,
compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
end

@testset "GRUv3Cell GPU AD" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y)
end

d_in, d_out, len, batch_size = 2, 3, 4, 5
r = GRUv3Cell(d_in => d_out)
x = [randn(Float32, d_in, batch_size) for _ in 1:len]
h = zeros(Float32, d_out)
@test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
@test test_gradients(r, x[1], h; test_gpu=true,
compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
@test test_gradients(r, x, h; test_gpu=true,
compare_finite_diff=false,
loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
end

@testset "GRUv3 GPU AD" begin
@@ -157,7 +135,9 @@ end
d_in, d_out, len, batch_size = 2, 3, 4, 5
model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
x_nobatch = randn(Float32, d_in, len)
@test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
@test test_gradients(model, x_nobatch; test_gpu=true,
compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
x = randn(Float32, d_in, len, batch_size)
@test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
@test test_gradients(model, x; test_gpu=true,
compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
end
20 changes: 7 additions & 13 deletions test/layers/recurrent.jl
@@ -156,37 +156,31 @@ end
model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4))

x = rand(Float32, 2, 3, 1)
h, c = model(x)
h = model(x)
@test h isa Array{Float32, 3}
@test size(h) == (4, 3, 1)
@test c isa Array{Float32, 3}
@test size(c) == (4, 3, 1)
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
test_gradients(model, x)

x = rand(Float32, 2, 3)
h, c = model(x)
h = model(x)
@test h isa Array{Float32, 2}
@test size(h) == (4, 3)
@test c isa Array{Float32, 2}
@test size(c) == (4, 3)
test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))

# test default initial states
lstm = model.lstm
h, c = lstm(x)
h = lstm(x)
@test h isa Array{Float32, 2}
@test size(h) == (4, 3)
@test c isa Array{Float32, 2}
@test size(c) == (4, 3)


# initial states are zero
h0, c0 = Flux.initialstates(lstm)
@test h0 ≈ zeros(Float32, 4)
@test c0 ≈ zeros(Float32, 4)

# no initial state same as zero initial state
h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
h1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
@test h ≈ h1
@test c ≈ c1
end

@testset "GRUCell" begin
