From 96e2b5e31b178f7aa491496a71a84a8ea8baf77f Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 12 Dec 2024 17:23:27 +0100
Subject: [PATCH] change cells output

---
 src/layers/recurrent.jl  |  61 +++++++++------------
 test/layers/recurrent.jl | 115 +++++++++++++++++----------------------
 2 files changed, 76 insertions(+), 100 deletions(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 9386a3fc2d..dbef6f9932 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,13 +1,8 @@
-out_from_state(state) = state
-out_from_state(state::Tuple) = state[1]
-
-function scan(cell, x, state0)
-    state = state0
+function scan(cell, x, state)
     y = []
     for x_t in eachslice(x, dims = 2)
-        state = cell(x_t, state)
-        out = out_from_state(state)
-        y = vcat(y, [out])
+        yt, state = cell(x_t, state)
+        y = vcat(y, [yt])
     end
     return stack(y, dims = 2)
 end
@@ -26,7 +21,7 @@ In the forward pass, implements the function
 ```math
 h^\prime = \sigma(W_i x + W_h h + b)
 ```
-and returns `h'`.
+Returns a tuple `(out, state)`, where both elements are given by `h'`.
 
 See [`RNN`](@ref) for a layer that processes entire sequences.
 
@@ -48,6 +43,8 @@ The arguments of the forward pass are:
 - `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
   If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
 
+Returns a tuple `(out, state)`, where both elements are given by the updated state `h'`.
+
 # Examples
 
 ```julia
@@ -64,10 +61,10 @@ h = zeros(Float32, 5)
 ŷ = []
 
 for x_t in x
-    h = r(x_t, h)
-    ŷ = [ŷ..., h]   # Cannot use `push!(ŷ, h)` here since mutation
-                    # is not automatic differentiation friendly yet.
-                    # Can use `y = vcat(y, [h])` as an alternative.
+    yt, h = r(x_t, h)
+    ŷ = [ŷ..., yt]  # Cannot use `push!(ŷ, yt)` here since mutation
+                    # is not automatic differentiation friendly yet.
+                    # Can use `ŷ = vcat(ŷ, [yt])` as an alternative.
 end
 
 h   # The final hidden state
@@ -107,28 +104,25 @@ res = rnn(x, h0)
 initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))
 
 function RNNCell(
-    (in, out)::Pair,
-    σ = tanh;
-    init_kernel = glorot_uniform,
-    init_recurrent_kernel = glorot_uniform,
-    bias = true,
-)
+        (in, out)::Pair,
+        σ = tanh;
+        init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform,
+        bias = true,
+    )
     Wi = init_kernel(out, in)
     Wh = init_recurrent_kernel(out, out)
     b = create_bias(Wi, bias, size(Wi, 1))
     return RNNCell(σ, Wi, Wh, b)
 end
 
-function (rnn::RNNCell)(x::AbstractVecOrMat)
-    state = initialstates(rnn)
-    return rnn(x, state)
-end
+(rnn::RNNCell)(x::AbstractVecOrMat) = rnn(x, initialstates(rnn))
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
     _size_check(m, x, 1 => size(m.Wi, 2))
     σ = NNlib.fast_act(m.σ, x)
     h = σ.(m.Wi * x .+ m.Wh * h .+ m.bias)
-    return h
+    return h, h
 end
 
 function Base.show(io::IO, m::RNNCell)
@@ -220,10 +214,7 @@ function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
     return RNN(cell)
 end
 
-function (rnn::RNN)(x::AbstractArray)
-    state = initialstates(rnn)
-    return rnn(x, state)
-end
+(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))
 
 function (m::RNN)(x::AbstractArray, h)
     @assert ndims(x) == 2 || ndims(x) == 3
@@ -325,9 +316,9 @@ function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
     b = m.bias
     g = m.Wi * x .+ m.Wh * h .+ b
     input, forget, cell, output = chunk(g, 4; dims = 1)
-    c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
-    h′ = @. sigmoid_fast(output) * tanh_fast(c′)
-    return h′, c′
+    c = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
+    h = @. sigmoid_fast(output) * tanh_fast(c)
+    return h, (h, c)
 end
 
 Base.show(io::IO, m::LSTMCell) =
@@ -509,8 +500,8 @@ function (m::GRUCell)(x::AbstractVecOrMat, h)
     r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
     z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
     h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
-    h′ = @. (1 - z) * h̃ + z * h
-    return h′
+    h = @. (1 - z) * h̃ + z * h
+    return h, h
 end
 
 Base.show(io::IO, m::GRUCell) =
@@ -664,8 +655,8 @@ function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
     r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
     z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
     h̃ = tanh_fast.(gxs[3] .+ (m.Wh_h̃ * (r .* h)) .+ bs[3])
-    h′ = @. (1 - z) * h̃ + z * h
-    return h′
+    h = @. (1 - z) * h̃ + z * h
+    return h, h
 end
 
 Base.show(io::IO, m::GRUv3Cell) =
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 864e5dad8e..ed68527ed1 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,36 +1,37 @@
-
-@testset "RNNCell" begin
-    function loss1(r, x, h)
-        for x_t in x
-            h = r(x_t, h)
-        end
-        return mean(h.^2)
+function cell_loss1(r, x, state)
+    for x_t in x
+        _, state = r(x_t, state)
     end
+    return mean(state[1])
+end
 
-    function loss2(r, x, h)
-        y = [r(x_t, h) for x_t in x]
-        return sum(mean, y)
-    end
+function cell_loss2(r, x, state)
+    y = [r(x_t, state)[1] for x_t in x]
+    return sum(mean, y)
+end
 
-    function loss3(r, x, h)
-        y = []
-        for x_t in x
-            h = r(x_t, h)
-            y = [y..., h]
-        end
-        return sum(mean, y)
+function cell_loss3(r, x, state)
+    y = []
+    for x_t in x
+        y_t, state = r(x_t, state)
+        y = [y..., y_t]
     end
+    return sum(mean, y)
+end
 
-    function loss4(r, x, h)
-        y = []
-        for x_t in x
-            h = r(x_t, h)
-            y = vcat(y, [h])
-        end
-        y = stack(y, dims=2) # [D, L] or [D, L, B]
-        return mean(y.^2)
+function cell_loss4(r, x, state)
+    y = []
+    for x_t in x
+        y_t, state = r(x_t, state)
+        y = vcat(y, [y_t])
     end
+    y = stack(y, dims=2) # [D, L] or [D, L, B]
+    return mean(y.^2)
+end
+
+@testset "RNNCell" begin
+
     r = RNNCell(3 => 5)
     @test length(Flux.trainables(r)) == 3
     # An input sequence of length 6 and batch size 4.
     x = [rand(Float32, 3, 4) for _ in 1:6]
 
     # Initial State is a single vector
     h = randn(Float32, 5)
-    test_gradients(r, x, h, loss=loss1)  # for loop
-    test_gradients(r, x, h, loss=loss2)  # comprehension
-    test_gradients(r, x, h, loss=loss3)  # splat
-    test_gradients(r, x, h, loss=loss4)  # vcat and stack
+    test_gradients(r, x, h, loss=cell_loss1)  # for loop
+    test_gradients(r, x, h, loss=cell_loss2)  # comprehension
+    test_gradients(r, x, h, loss=cell_loss3)  # splat
+    test_gradients(r, x, h, loss=cell_loss4)  # vcat and stack
 
     # initial states are zero
    @test Flux.initialstates(r) ≈ zeros(Float32, 5)
 
     # no initial state same as zero initial state
-    @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
+    out, state = r(x[1])
+    @test out === state
+    @test out ≈ r(x[1], zeros(Float32, 5))[1]
 
     # Now initial state has a batch dimension.
     h = randn(Float32, 5, 4)
-    test_gradients(r, x, h, loss=loss4)
+    test_gradients(r, x, h, loss=cell_loss4)
 
     # The input sequence has no batch dimension.
     x = [rand(Float32, 3) for _ in 1:6]
     h = randn(Float32, 5)
-    test_gradients(r, x, h, loss=loss4)
+    test_gradients(r, x, h, loss=cell_loss4)
 
     # No Bias
     r = RNNCell(3 => 5, bias=false)
     @test length(Flux.trainables(r)) == 2
-    test_gradients(r, x, h, loss=loss4)
+    test_gradients(r, x, h, loss=cell_loss4)
 end
 
 @testset "RNN" begin
@@ -99,32 +102,20 @@ end
 
 @testset "LSTMCell" begin
-    function loss(r, x, hc)
-        h, c = hc
-        h′ = []
-        c′ = []
-        for x_t in x
-            h, c = r(x_t, (h, c))
-            h′ = vcat(h′, [h])
-            c′ = [c′..., c]
-        end
-        hnew = stack(h′, dims=2)
-        cnew = stack(c′, dims=2)
-        return mean(hnew.^2) + mean(cnew.^2)
-    end
-
     cell = LSTMCell(3 => 5)
     @test length(Flux.trainables(cell)) == 3
 
     x = [rand(Float32, 3, 4) for _ in 1:6]
     h = zeros(Float32, 5, 4)
     c = zeros(Float32, 5, 4)
-    hnew, cnew = cell(x[1], (h, c))
+    out, state = cell(x[1], (h, c))
+    hnew, cnew = state
+    @test out === hnew
     @test hnew isa Matrix{Float32}
     @test cnew isa Matrix{Float32}
     @test size(hnew) == (5, 4)
     @test size(cnew) == (5, 4)
 
     test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
-    test_gradients(cell, x, (h, c), loss = loss)
+    test_gradients(cell, x, (h, c), loss = cell_loss4)
 
     # initial states are zero
     h0, c0 = Flux.initialstates(cell)
     @test h0 ≈ zeros(Float32, 5)
     @test c0 ≈ zeros(Float32, 5)
 
     # no initial state same as zero initial state
-    hnew1, cnew1 = cell(x[1])
-    hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
+    _, (hnew1, cnew1) = cell(x[1])
+    _, (hnew2, cnew2) = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
     @test hnew1 ≈ hnew2
     @test cnew1 ≈ cnew2
@@ -184,16 +175,6 @@ end
 
 @testset "GRUCell" begin
-    function loss(r, x, h)
-        y = []
-        for x_t in x
-            h = r(x_t, h)
-            y = vcat(y, [h])
-        end
-        y = stack(y, dims=2) # [D, L] or [D, L, B]
-        return mean(y.^2)
-    end
-
     r = GRUCell(3 => 5)
     @test length(Flux.trainables(r)) == 3
     # An input sequence of length 6 and batch size 4.
     x = [rand(Float32, 3, 4) for _ in 1:6]
 
     # Initial State is a single vector
     h = randn(Float32, 5)
-    test_gradients(r, x, h; loss)
+    out, state = r(x[1], h)
+    @test out === state
+    test_gradients(r, x, h; loss = cell_loss4)
 
     # initial states are zero
     @test Flux.initialstates(r) ≈ zeros(Float32, 5)
 
     # no initial state same as zero initial state
     @test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
 
     # Now initial state has a batch dimension.
     h = randn(Float32, 5, 4)
-    test_gradients(r, x, h; loss)
+    test_gradients(r, x, h; loss = cell_loss4)
 
     # The input sequence has no batch dimension.
     x = [rand(Float32, 3) for _ in 1:6]
     h = randn(Float32, 5)
-    test_gradients(r, x, h; loss)
+    test_gradients(r, x, h; loss = cell_loss4)
 
     # No Bias
     r = GRUCell(3 => 5, bias=false)
@@ -262,6 +245,8 @@ end
 
     # Initial State is a single vector
     h = randn(Float32, 5)
+    out, state = r(x, h)
+    @test out === state
     test_gradients(r, x, h)
 
     # initial states are zero
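
For orientation only, here is a minimal usage sketch of the `(out, state)` convention this patch introduces, assuming a Flux build that includes the change. The `rollout` helper is hypothetical, a stand-in for the patched internal `scan`, and is not part of the patch.

```julia
using Flux

# Hypothetical helper mirroring the patched `scan`: every cell call now
# returns `(out, state)`, so the same loop works for RNNCell, LSTMCell,
# GRUCell and GRUv3Cell.
function rollout(cell, xs, state)
    ys = []
    for x_t in xs
        y_t, state = cell(x_t, state)   # output and recurrent state come back separately
        ys = vcat(ys, [y_t])            # AD-friendly alternative to push!
    end
    return stack(ys, dims = 2), state
end

xs = [rand(Float32, 3) for _ in 1:6]

rnn = RNNCell(3 => 5)
y, h = rollout(rnn, xs, zeros(Float32, 5))   # y is 5×6; for RNNCell, out === state

lstm = LSTMCell(3 => 5)
y2, (h2, c2) = rollout(lstm, xs, (zeros(Float32, 5), zeros(Float32, 5)))  # out === h2; state is (h, c)
```

The identity between `out` and the hidden state for the RNN and GRU cells, and between `out` and `h` for the LSTM cell, is exactly what the updated tests in the diff assert with `@test out === state` and `@test out === hnew`.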