From 130af41ffbc088b4f5797ac1ccd0f85b81a13f38 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 11 Dec 2024 15:46:52 +0100
Subject: [PATCH] hotfix LSTM output (#2547)

---
 Project.toml                        |   2 +-
 src/layers/recurrent.jl             |  99 +++++++++++----------------
 test/ext_common/recurrent_gpu_ad.jl | 100 +++++++++++-----------------
 test/layers/recurrent.jl            |  20 ++----
 4 files changed, 87 insertions(+), 134 deletions(-)

diff --git a/Project.toml b/Project.toml
index 20e133811f..ade4595b71 100644
--- a/Project.toml
+++ b/Project.toml
@@ -48,8 +48,8 @@ CUDA = "5"
 ChainRulesCore = "1.12"
 Compat = "4.10.0"
 Enzyme = "0.13"
-Functors = "0.5"
 EnzymeCore = "0.7.7, 0.8.4"
+Functors = "0.5"
 MLDataDevices = "1.4.2"
 MLUtils = "0.4"
 MPI = "0.20.19"
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index a170bb2d3d..9386a3fc2d 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,6 +1,20 @@
+out_from_state(state) = state
+out_from_state(state::Tuple) = state[1]
+
+function scan(cell, x, state0)
+    state = state0
+    y = []
+    for x_t in eachslice(x, dims = 2)
+        state = cell(x_t, state)
+        out = out_from_state(state)
+        y = vcat(y, [out])
+    end
+    return stack(y, dims = 2)
+end
+
 
-# Vanilla RNN 
+# Vanilla RNN
 @doc raw"""
     RNNCell(in => out, σ = tanh; init_kernel = glorot_uniform,
         init_recurrent_kernel = glorot_uniform, bias = true)
@@ -215,13 +229,7 @@ function (m::RNN)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
     # [x] = [in, L] or [in, L, B]
     # [h] = [out] or [out, B]
-    y = []
-    for x_t in eachslice(x, dims = 2)
-        h = m.cell(x_t, h)
-        # y = [y..., h]
-        y = vcat(y, [h])
-    end
-    return stack(y, dims = 2)
+    return scan(m.cell, x, h)
 end
@@ -297,11 +305,12 @@ function initialstates(lstm:: LSTMCell)
 end
 
 function LSTMCell(
-    (in, out)::Pair;
-    init_kernel = glorot_uniform,
-    init_recurrent_kernel = glorot_uniform,
-    bias = true,
-)
+        (in, out)::Pair;
+        init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform,
+        bias = true,
+    )
+
     Wi = init_kernel(out * 4, in)
     Wh = init_recurrent_kernel(out * 4, out)
     b = create_bias(Wi, bias, out * 4)
@@ -309,10 +318,7 @@ function LSTMCell(
     return cell
 end
 
-function (lstm::LSTMCell)(x::AbstractVecOrMat)
-    state, cstate = initialstates(lstm)
-    return lstm(x, (state, cstate))
-end
+(lstm::LSTMCell)(x::AbstractVecOrMat) = lstm(x, initialstates(lstm))
 
 function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
     _size_check(m, x, 1 => size(m.Wi, 2))
@@ -368,15 +374,14 @@ The arguments of the forward pass are:
 They should be vectors of size `out` or matrices of size `out x batch_size`.
 If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref).
 
-Returns a tuple `(h′, c′)` containing all new hidden states `h_t` and cell states `c_t`
-in tensors of size `out x len` or `out x len x batch_size`.
+Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`.
 
 # Examples
 
 ```julia
 struct Model
     lstm::LSTM
-    h0::AbstractVector
+    h0::AbstractVector # trainable initial hidden state
     c0::AbstractVector
 end
 
 (m::Model)(x) = m.lstm(x, (m.h0, m.c0))
@@ -387,7 +392,7 @@ Flux.@layer Model
 d_in, d_out, len, batch_size = 2, 3, 4, 5
 x = rand(Float32, (d_in, len, batch_size))
 model = Model(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
-h, c = model(x)
+h = model(x)
 size(h) # out x len x batch_size
 ```
 """
@@ -404,21 +409,11 @@ function LSTM((in, out)::Pair; cell_kwargs...)
     return LSTM(cell)
 end
 
-function (lstm::LSTM)(x::AbstractArray)
-    state, cstate = initialstates(lstm)
-    return lstm(x, (state, cstate))
-end
+(lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm))
 
-function (m::LSTM)(x::AbstractArray, (h, c))
+function (m::LSTM)(x::AbstractArray, state0)
     @assert ndims(x) == 2 || ndims(x) == 3
-    h′ = []
-    c′ = []
-    for x_t in eachslice(x, dims = 2)
-        h, c = m.cell(x_t, (h, c))
-        h′ = vcat(h′, [h])
-        c′ = vcat(c′, [c])
-    end
-    return stack(h′, dims = 2), stack(c′, dims = 2)
+    return scan(m.cell, x, state0)
 end
 
 # GRU
@@ -485,11 +480,12 @@ end
 
 initialstates(gru::GRUCell) = zeros_like(gru.Wh, size(gru.Wh, 2))
 
 function GRUCell(
-    (in, out)::Pair;
-    init_kernel = glorot_uniform,
-    init_recurrent_kernel = glorot_uniform,
-    bias = true,
-)
+        (in, out)::Pair;
+        init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform,
+        bias = true,
+    )
+
     Wi = init_kernel(out * 3, in)
     Wh = init_recurrent_kernel(out * 3, out)
     b = create_bias(Wi, bias, size(Wi, 1))
@@ -581,20 +577,11 @@ function GRU((in, out)::Pair; cell_kwargs...)
     return GRU(cell)
 end
 
-function (gru::GRU)(x::AbstractArray)
-    state = initialstates(gru)
-    return gru(x, state)
-end
+(gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru))
 
 function (m::GRU)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
-    h′ = []
-    # [x] = [in, L] or [in, L, B]
-    for x_t in eachslice(x, dims = 2)
-        h = m.cell(x_t, h)
-        h′ = vcat(h′, [h])
-    end
-    return stack(h′, dims = 2)
+    return scan(m.cell, x, h)
 end
 
 # GRU v3
@@ -750,17 +737,9 @@ function GRUv3((in, out)::Pair; cell_kwargs...)
     return GRUv3(cell)
 end
 
-function (gru::GRUv3)(x::AbstractArray)
-    state = initialstates(gru)
-    return gru(x, state)
-end
+(gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru))
 
 function (m::GRUv3)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
-    h′ = []
-    for x_t in eachslice(x, dims = 2)
-        h = m.cell(x_t, h)
-        h′ = vcat(h′, [h])
-    end
-    return stack(h′, dims = 2)
+    return scan(m.cell, x, h)
 end
diff --git a/test/ext_common/recurrent_gpu_ad.jl b/test/ext_common/recurrent_gpu_ad.jl
index 704f147f60..3046f1f1fa 100644
--- a/test/ext_common/recurrent_gpu_ad.jl
+++ b/test/ext_common/recurrent_gpu_ad.jl
@@ -1,24 +1,28 @@
-
-@testset "RNNCell GPU AD" begin
-    function loss(r, x, h)
-        y = []
-        for x_t in x
-            h = r(x_t, h)
-            y = vcat(y, [h])
-        end
-        # return mean(h)
-        y = stack(y, dims=2) # [D, L] or [D, L, B]
-        return mean(y)
+out_from_state(state::Tuple) = state[1]
+out_from_state(state) = state
+
+function recurrent_cell_loss(cell, seq, state)
+    out = []
+    for xt in seq
+        state = cell(xt, state)
+        yt = out_from_state(state)
+        out = vcat(out, [yt])
     end
+    return mean(stack(out, dims = 2))
+end
+
+@testset "RNNCell GPU AD" begin
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     r = RNNCell(d_in => d_out)
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     # Single Step
-    @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
+    @test test_gradients(r, x[1], h; test_gpu=true,
+        compare_finite_diff=false) broken = :rnncell_single ∈ BROKEN_TESTS
     # Multiple Steps
-    @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(r, x, h; test_gpu=true,
+        compare_finite_diff=false,
+        loss=recurrent_cell_loss) broken = :rnncell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "RNN GPU AD" begin
@@ -40,21 +44,6 @@ end
 end
 
 @testset "LSTMCell" begin
-
-    function loss(r, x, hc)
-        h, c = hc
-        h′ = []
-        c′ = []
-        for x_t in x
-            h, c = r(x_t, (h, c))
-            h′ = vcat(h′, [h])
-            c′ = [c′..., c]
-        end
-        hnew = stack(h′, dims=2)
-        cnew = stack(c′, dims=2)
-        return mean(hnew) + mean(cnew)
-    end
-
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     cell = LSTMCell(d_in => d_out)
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
@@ -64,7 +53,9 @@
     # Single Step
     @test test_gradients(cell, x[1], (h, c); test_gpu=true, compare_finite_diff=false,
        loss = (m, x, (h, c)) -> mean(m(x, (h, c))[1])) broken = :lstmcell_single ∈ BROKEN_TESTS
     # Multiple Steps
-    @test test_gradients(cell, x, (h, c); test_gpu=true, compare_finite_diff=false, loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(cell, x, (h, c); test_gpu=true,
+        compare_finite_diff = false,
+        loss = recurrent_cell_loss) broken = :lstmcell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "LSTM" begin
     struct ModelLSTM
         lstm::LSTM
         h0::AbstractVector
         c0::AbstractVector
     end
 
     (m::ModelLSTM)(x) = m.lstm(x, (m.h0, m.c0))
 
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelLSTM(LSTM(d_in => d_out), zeros(Float32, d_out), zeros(Float32, d_out))
     x_nobatch = randn(Float32, d_in, len)
-    @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false,
-        loss = (m, x) -> mean(m(x)[1])) broken = :lstm_nobatch ∈ BROKEN_TESTS
+    @test test_gradients(model, x_nobatch; test_gpu=true,
+        compare_finite_diff=false) broken = :lstm_nobatch ∈ BROKEN_TESTS
     x = randn(Float32, d_in, len, batch_size)
-    @test test_gradients(model, x; test_gpu=true, compare_finite_diff=false,
-        loss = (m, x) -> mean(m(x)[1])) broken = :lstm ∈ BROKEN_TESTS
+    @test test_gradients(model, x; test_gpu=true,
+        compare_finite_diff=false) broken = :lstm ∈ BROKEN_TESTS
 end
 
 @testset "GRUCell" begin
-    function loss(r, x, h)
-        y = []
-        for x_t in x
-            h = r(x_t, h)
-            y = vcat(y, [h])
-        end
-        y = stack(y, dims=2) # [D, L] or [D, L, B]
-        return mean(y)
-    end
-
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     r = GRUCell(d_in => d_out)
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
     @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :grucell_single ∈ BROKEN_TESTS
-    @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :grucell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(r, x, h; test_gpu=true,
+        compare_finite_diff = false,
+        loss = recurrent_cell_loss) broken = :grucell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "GRU GPU AD" begin
     struct ModelGRU
         gru::GRU
         h0::AbstractVector
     end
 
     (m::ModelGRU)(x) = m.gru(x, m.h0)
 
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelGRU(GRU(d_in => d_out), zeros(Float32, d_out))
     x_nobatch = randn(Float32, d_in, len)
-    @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
+    @test test_gradients(model, x_nobatch; test_gpu=true,
+        compare_finite_diff=false) broken = :gru_nobatch ∈ BROKEN_TESTS
     x = randn(Float32, d_in, len, batch_size)
-    @test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
+    @test test_gradients(model, x; test_gpu=true,
+        compare_finite_diff=false) broken = :gru ∈ BROKEN_TESTS
 end
 
 @testset "GRUv3Cell GPU AD" begin
-    function loss(r, x, h)
-        y = []
-        for x_t in x
-            h = r(x_t, h)
-            y = vcat(y, [h])
-        end
-        y = stack(y, dims=2) # [D, L] or [D, L, B]
-        return mean(y)
-    end
-
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     r = GRUv3Cell(d_in => d_out)
     x = [randn(Float32, d_in, batch_size) for _ in 1:len]
     h = zeros(Float32, d_out)
-    @test test_gradients(r, x[1], h; test_gpu=true, compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
-    @test test_gradients(r, x, h; test_gpu=true, compare_finite_diff=false, loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
+    @test test_gradients(r, x[1], h; test_gpu=true,
+        compare_finite_diff=false) broken = :gruv3cell_single ∈ BROKEN_TESTS
+    @test test_gradients(r, x, h; test_gpu=true,
+        compare_finite_diff=false,
+        loss = recurrent_cell_loss) broken = :gruv3cell_multiple ∈ BROKEN_TESTS
 end
 
 @testset "GRUv3 GPU AD" begin
     struct ModelGRUv3
         gru::GRUv3
         h0::AbstractVector
     end
 
     (m::ModelGRUv3)(x) = m.gru(x, m.h0)
 
     d_in, d_out, len, batch_size = 2, 3, 4, 5
     model = ModelGRUv3(GRUv3(d_in => d_out), zeros(Float32, d_out))
     x_nobatch = randn(Float32, d_in, len)
-    @test test_gradients(model, x_nobatch; test_gpu=true, compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
+    @test test_gradients(model, x_nobatch; test_gpu=true,
+        compare_finite_diff=false) broken = :gruv3_nobatch ∈ BROKEN_TESTS
     x = randn(Float32, d_in, len, batch_size)
-    @test test_gradients(model, x; test_gpu=true, compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
+    @test test_gradients(model, x; test_gpu=true,
+        compare_finite_diff=false) broken = :gruv3 ∈ BROKEN_TESTS
 end
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index f4f4777fd2..864e5dad8e 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -156,37 +156,31 @@ end
     model = ModelLSTM(LSTM(2 => 4), zeros(Float32, 4), zeros(Float32, 4))
 
     x = rand(Float32, 2, 3, 1)
-    h, c = model(x)
+    h = model(x)
     @test h isa Array{Float32, 3}
     @test size(h) == (4, 3, 1)
-    @test c isa Array{Float32, 3}
-    @test size(c) == (4, 3, 1)
-    test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
+    test_gradients(model, x)
 
     x = rand(Float32, 2, 3)
-    h, c = model(x)
+    h = model(x)
     @test h isa Array{Float32, 2}
     @test size(h) == (4, 3)
-    @test c isa Array{Float32, 2}
-    @test size(c) == (4, 3)
     test_gradients(model, x, loss = (m, x) -> mean(m(x)[1]))
 
+    # test default initial states
     lstm = model.lstm
-    h, c = lstm(x)
+    h = lstm(x)
     @test h isa Array{Float32, 2}
     @test size(h) == (4, 3)
-    @test c isa Array{Float32, 2}
-    @test size(c) == (4, 3)
-
+    # initial states are zero
     h0, c0 = Flux.initialstates(lstm)
     @test h0 ≈ zeros(Float32, 4)
     @test c0 ≈ zeros(Float32, 4)
 
     # no initial state same as zero initial state
-    h1, c1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
+    h1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4)))
     @test h ≈ h1
-    @test c ≈ c1
 end
 
 @testset "GRUCell" begin
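
Below the patch, a minimal usage sketch of the new output convention. It is illustrative only: it assumes a Flux build that includes this commit, and the sizes are arbitrary.

```julia
using Flux

d_in, d_out, len, batch_size = 2, 3, 4, 5
lstm = LSTM(d_in => d_out)
x = rand(Float32, d_in, len, batch_size)

# After this patch the LSTM layer returns only the stacked hidden states,
# one slice per time step along dimension 2, not an (h, c) tuple.
h = lstm(x)
@assert size(h) == (d_out, len, batch_size)

# An explicit initial state is still passed as a (hidden, cell) tuple;
# omitting it is equivalent to starting from zeros.
h0, c0 = Flux.initialstates(lstm)
@assert lstm(x, (h0, c0)) ≈ h
```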
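
The `scan` helper added at the top of `src/layers/recurrent.jl` is the single code path all four layer types now dispatch to. The sketch below is a hand-rolled equivalent of that loop for reference; the name `manual_scan` is hypothetical, and `scan` / `out_from_state` themselves are internal, unexported helpers.

```julia
using Flux

# Reference re-implementation of the scan loop introduced by the patch:
# step the cell across dim 2 of x, collect each step's output, stack at the end.
function manual_scan(cell, x, state0)
    state = state0
    ys = []
    for x_t in eachslice(x, dims = 2)
        state = cell(x_t, state)
        # LSTMCell carries an (h, c) tuple whose output is h;
        # RNNCell and GRUCell use the state itself as the output.
        y_t = state isa Tuple ? state[1] : state
        ys = vcat(ys, [y_t])
    end
    return stack(ys, dims = 2)
end

cell = LSTMCell(2 => 3)
x = rand(Float32, 2, 4, 5)  # in x len x batch
h = manual_scan(cell, x, Flux.initialstates(cell))
@assert size(h) == (3, 4, 5)
```

This mirrors why the LSTM layer now returns a single array: only the hidden half of the cell state is collected at each step, which is exactly what `out_from_state` selects.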