change cells output
CarloLucibello committed Dec 12, 2024
1 parent 2bbd8b3 commit 96e2b5e
Showing 2 changed files with 76 additions and 100 deletions.
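The substance of the change: every recurrent cell now returns a tuple `(out, state)` instead of the bare updated state. A minimal sketch of the new calling convention, assuming this commit is applied (the names `cell`, `x`, `h0` are illustrative and not part of the diff):

```julia
using Flux

cell = RNNCell(3 => 5)
x  = rand(Float32, 3)      # a single time step
h0 = zeros(Float32, 5)

# Before this commit the cell returned the updated hidden state directly;
# now it returns `(out, state)`.
out, state = cell(x, h0)
out === state              # true for RNNCell: output and new state coincide
```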
61 changes: 26 additions & 35 deletions src/layers/recurrent.jl
@@ -1,13 +1,8 @@
out_from_state(state) = state
out_from_state(state::Tuple) = state[1]

function scan(cell, x, state0)
state = state0
function scan(cell, x, state)
y = []
for x_t in eachslice(x, dims = 2)
state = cell(x_t, state)
out = out_from_state(state)
y = vcat(y, [out])
yt, state = cell(x_t, state)
y = vcat(y, [yt])
end
return stack(y, dims = 2)
end
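Aside (not part of the diff): under the new convention the internal `scan` helper expects the cell to return `(y_t, state)` and stacks the per-step outputs along the time dimension. A hedged sketch, assuming `scan` remains an unexported helper reachable as `Flux.scan`:

```julia
using Flux

cell = RNNCell(3 => 5)
x  = rand(Float32, 3, 10)      # layout (features, time)
h0 = zeros(Float32, 5)

y = Flux.scan(cell, x, h0)     # feeds each slice x[:, t] through the cell
size(y)                        # (5, 10): one output column per time step
```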
@@ -26,7 +21,7 @@ In the forward pass, implements the function
```math
h^\prime = \sigma(W_i x + W_h h + b)
```
and returns `h'`.
Returns a tuple `(out, state)`, where both elements are given by `h'`.
See [`RNN`](@ref) for a layer that processes entire sequences.
@@ -48,6 +43,8 @@ The arguments of the forward pass are:
- `h`: The hidden state of the RNN. It should be a vector of size `out` or a matrix of size `out x batch_size`.
If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref).
Returns a tuple `(out, state)`, where both elements are given by the updated state `h'`.
# Examples
```julia
@@ -64,10 +61,10 @@ h = zeros(Float32, 5)
ŷ = []
for x_t in x
h = r(x_t, h)
ŷ = [ŷ..., h] # Cannot use `push!(ŷ, h)` here since mutation
# is not automatic differentiation friendly yet.
# Can use `y = vcat(y, [h])` as an alternative.
yt, h = r(x_t, h)
ŷ = [ŷ..., yt] # Cannot use `push!(ŷ, yt)` here since mutation
# is not automatic differentiation friendly yet.
# Can use `ŷ = vcat(ŷ, [yt])` as an alternative.
end
h # The final hidden state
@@ -107,28 +104,25 @@ res = rnn(x, h0)
initialstates(rnn::RNNCell) = zeros_like(rnn.Wh, size(rnn.Wh, 2))

function RNNCell(
(in, out)::Pair,
σ = tanh;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
(in, out)::Pair,
σ = tanh;
init_kernel = glorot_uniform,
init_recurrent_kernel = glorot_uniform,
bias = true,
)
Wi = init_kernel(out, in)
Wh = init_recurrent_kernel(out, out)
b = create_bias(Wi, bias, size(Wi, 1))
return RNNCell(σ, Wi, Wh, b)
end

function (rnn::RNNCell)(x::AbstractVecOrMat)
state = initialstates(rnn)
return rnn(x, state)
end
(rnn::RNNCell)(x::AbstractVecOrMat) = rnn(x, initialstates(rnn))

function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
_size_check(m, x, 1 => size(m.Wi, 2))
σ = NNlib.fast_act(m.σ, x)
h = σ.(m.Wi * x .+ m.Wh * h .+ m.bias)
return h
return h, h
end

function Base.show(io::IO, m::RNNCell)
@@ -220,10 +214,7 @@ function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
return RNN(cell)
end

function (rnn::RNN)(x::AbstractArray)
state = initialstates(rnn)
return rnn(x, state)
end
(rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn))

function (m::RNN)(x::AbstractArray, h)
@assert ndims(x) == 2 || ndims(x) == 3
@@ -325,9 +316,9 @@ function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
b = m.bias
g = m.Wi * x .+ m.Wh * h .+ b
input, forget, cell, output = chunk(g, 4; dims = 1)
c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h′ = @. sigmoid_fast(output) * tanh_fast(c′)
return h′, c′
c = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
h = @. sigmoid_fast(output) * tanh_fast(c)
return h, (h, c)
end
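A hedged usage sketch (illustrative, not from the diff) of the new `LSTMCell` output: the first element is the output `h`, the second is the full recurrent state `(h, c)`.

```julia
using Flux

cell = LSTMCell(3 => 5)
x = rand(Float32, 3, 4)                               # batch of 4
state0 = (zeros(Float32, 5, 4), zeros(Float32, 5, 4))

out, (h, c) = cell(x, state0)
out === h                                             # true: the output aliases the hidden state
size(h) == size(c) == (5, 4)                          # both parts of the state are out × batch
```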

Base.show(io::IO, m::LSTMCell) =
@@ -509,8 +500,8 @@ function (m::GRUCell)(x::AbstractVecOrMat, h)
r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
h̃ = @. tanh_fast(gxs[3] + r * ghs[3] + bs[3])
h′ = @. (1 - z) * h̃ + z * h
return h′
h = @. (1 - z) * h̃ + z * h
return h, h
end

Base.show(io::IO, m::GRUCell) =
@@ -664,8 +655,8 @@ function (m::GRUv3Cell)(x::AbstractVecOrMat, h)
r = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
z = @. sigmoid_fast(gxs[2] + ghs[2] + bs[2])
h̃ = tanh_fast.(gxs[3] .+ (m.Wh_h̃ * (r .* h)) .+ bs[3])
h′ = @. (1 - z) * h̃ + z * h
return h′
h = @. (1 - z) * h̃ + z * h
return h, h
end
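Likewise for both GRU variants, a hedged sketch (illustrative names) of the new return value: the output and the new state are the same array.

```julia
using Flux

x, h0 = rand(Float32, 3), zeros(Float32, 5)

out,  h  = GRUCell(3 => 5)(x, h0)
out3, h3 = GRUv3Cell(3 => 5)(x, h0)
out === h && out3 === h3          # true: each cell returns `(h, h)`
```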

Base.show(io::IO, m::GRUv3Cell) =
115 changes: 50 additions & 65 deletions test/layers/recurrent.jl
@@ -1,68 +1,71 @@

@testset "RNNCell" begin
function loss1(r, x, h)
for x_t in x
h = r(x_t, h)
end
return mean(h.^2)
function cell_loss1(r, x, state)
for x_t in x
_, state = r(x_t, state)
end
return mean(state[1])
end

function loss2(r, x, h)
y = [r(x_t, h) for x_t in x]
return sum(mean, y)
end
function cell_loss2(r, x, state)
y = [r(x_t, state)[1] for x_t in x]
return sum(mean, y)
end

function loss3(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = [y..., h]
end
return sum(mean, y)
function cell_loss3(r, x, state)
y = []
for x_t in x
y_t, state = r(x_t, state)
y = [y..., y_t]
end
return sum(mean, y)
end

function loss4(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y.^2)
function cell_loss4(r, x, state)
y = []
for x_t in x
y_t, state = r(x_t, state)
y = vcat(y, [y_t])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y.^2)
end
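A hedged sketch of how these loss helpers are driven, assuming `Flux` and `Statistics` (for `mean`) are loaded as elsewhere in the test suite:

```julia
r = RNNCell(3 => 5)
x = [rand(Float32, 3, 4) for _ in 1:6]   # 6 time steps, batch of 4
h = zeros(Float32, 5, 4)

cell_loss1(r, x, h)   # scalar: first entry of the final state
cell_loss4(r, x, h)   # scalar: mean of the squared, stacked outputs
```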


@testset "RNNCell" begin

r = RNNCell(3 => 5)
@test length(Flux.trainables(r)) == 3
# An input sequence of length 6 and batch size 4.
x = [rand(Float32, 3, 4) for _ in 1:6]

# Initial State is a single vector
h = randn(Float32, 5)
test_gradients(r, x, h, loss=loss1) # for loop
test_gradients(r, x, h, loss=loss2) # comprehension
test_gradients(r, x, h, loss=loss3) # splat
test_gradients(r, x, h, loss=loss4) # vcat and stack
test_gradients(r, x, h, loss=cell_loss1) # for loop
test_gradients(r, x, h, loss=cell_loss2) # comprehension
test_gradients(r, x, h, loss=cell_loss3) # splat
test_gradients(r, x, h, loss=cell_loss4) # vcat and stack

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)

# no initial state same as zero initial state
@test r(x[1]) ≈ r(x[1], zeros(Float32, 5))
out, state = r(x[1])
@test out === state
@test out ≈ r(x[1], zeros(Float32, 5))[1]

# Now initial state has a batch dimension.
h = randn(Float32, 5, 4)
test_gradients(r, x, h, loss=loss4)
test_gradients(r, x, h, loss=cell_loss4)

# The input sequence has no batch dimension.
x = [rand(Float32, 3) for _ in 1:6]
h = randn(Float32, 5)
test_gradients(r, x, h, loss=loss4)
test_gradients(r, x, h, loss=cell_loss4)


# No Bias
r = RNNCell(3 => 5, bias=false)
@test length(Flux.trainables(r)) == 2
test_gradients(r, x, h, loss=loss4)
test_gradients(r, x, h, loss=cell_loss4)
end

@testset "RNN" begin
@@ -99,41 +102,29 @@ end

@testset "LSTMCell" begin

function loss(r, x, hc)
h, c = hc
h′ = []
c′ = []
for x_t in x
h, c = r(x_t, (h, c))
h′ = vcat(h′, [h])
c′ = [c′..., c]
end
hnew = stack(h′, dims=2)
cnew = stack(c′, dims=2)
return mean(hnew.^2) + mean(cnew.^2)
end

cell = LSTMCell(3 => 5)
@test length(Flux.trainables(cell)) == 3
x = [rand(Float32, 3, 4) for _ in 1:6]
h = zeros(Float32, 5, 4)
c = zeros(Float32, 5, 4)
hnew, cnew = cell(x[1], (h, c))
out, state = cell(x[1], (h, c))
hnew, cnew = state
@test out === hnew
@test hnew isa Matrix{Float32}
@test cnew isa Matrix{Float32}
@test size(hnew) == (5, 4)
@test size(cnew) == (5, 4)
test_gradients(cell, x[1], (h, c), loss = (m, x, hc) -> mean(m(x, hc)[1]))
test_gradients(cell, x, (h, c), loss = loss)
test_gradients(cell, x, (h, c), loss = cell_loss4)

# initial states are zero
h0, c0 = Flux.initialstates(cell)
@test h0 ≈ zeros(Float32, 5)
@test c0 ≈ zeros(Float32, 5)

# no initial state same as zero initial state
hnew1, cnew1 = cell(x[1])
hnew2, cnew2 = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
_, (hnew1, cnew1) = cell(x[1])
_, (hnew2, cnew2) = cell(x[1], (zeros(Float32, 5), zeros(Float32, 5)))
@test hnew1 ≈ hnew2
@test cnew1 ≈ cnew2

@@ -184,24 +175,16 @@ end
end

@testset "GRUCell" begin
function loss(r, x, h)
y = []
for x_t in x
h = r(x_t, h)
y = vcat(y, [h])
end
y = stack(y, dims=2) # [D, L] or [D, L, B]
return mean(y.^2)
end

r = GRUCell(3 => 5)
@test length(Flux.trainables(r)) == 3
# An input sequence of length 6 and batch size 4.
x = [rand(Float32, 3, 4) for _ in 1:6]

# Initial State is a single vector
h = randn(Float32, 5)
test_gradients(r, x, h; loss)
out, state = r(x[1], h)
@test out === state
test_gradients(r, x, h; loss = cell_loss4)

# initial states are zero
@test Flux.initialstates(r) ≈ zeros(Float32, 5)
@@ -211,12 +194,12 @@

# Now initial state has a batch dimension.
h = randn(Float32, 5, 4)
test_gradients(r, x, h; loss)
test_gradients(r, x, h; loss = cell_loss4)

# The input sequence has no batch dimension.
x = [rand(Float32, 3) for _ in 1:6]
h = randn(Float32, 5)
test_gradients(r, x, h; loss)
test_gradients(r, x, h; loss = cell_loss4)

# No Bias
r = GRUCell(3 => 5, bias=false)
@@ -262,6 +245,8 @@ end

# Initial State is a single vector
h = randn(Float32, 5)
out, state = r(x, h)
@test out === state
test_gradients(r, x, h)

# initial states are zero
