Rename the recurrent cells' initial-state field from `state` to `state0`.

  * RNNCell, LSTMCell and GRUCell keep their initial state in `state0`;
    `Recur(cell)` and `reset!(::Recur)` read it from there.
  * Temporary `Base.getproperty` shims keep the old accessors `Recur.init`,
    `RNNCell.h`, `LSTMCell.h`, `LSTMCell.c` and `GRUCell.h` working; each
    emits a warning and is marked for removal in v0.13.
  * `Recur` assigns the new state directly from the cell's return value.
  * The BPTT test is split into a 1D-input and a 2D-input test set.

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy, so the
leading whitespace of context lines is inferred. If a hunk is rejected, apply
with `git apply --ignore-whitespace` (or `patch -l`). The `index` hashes were
not recomputed.

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 70ba5e6988..f56d47cbd7 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -31,8 +31,7 @@ mutable struct Recur{T,S}
 end
 
 function (m::Recur)(xs...)
-  h, y = m.cell(m.state, xs...)
-  m.state = h
+  m.state, y = m.cell(m.state, xs...)
   return y
 end
 
@@ -51,9 +50,19 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
-rnn.state = hidden(rnn.cell)
+rnn.state = rnn.cell.state0
 ```
 """
-reset!(m::Recur) = (m.state = m.cell.state)
+reset!(m::Recur) = (m.state = m.cell.state0)
 reset!(m) = foreach(reset!, functor(m)[1])
 
+# TODO remove in v0.13
+function Base.getproperty(m::Recur, sym::Symbol)
+  if sym === :init
+    @warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead."
+    return getfield(m.cell, :state0)
+  else
+    return getfield(m, sym)
+  end
+end
+
 flip(f, xs) = reverse(f.(reverse(xs)))
 
 # Vanilla RNN
@@ -63,7 +72,7 @@ struct RNNCell{F,A,V,S}
   Wi::A
   Wh::A
   b::V
-  state::S
+  state0::S
 end
 
 RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
@@ -89,16 +98,26 @@ end
 The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
 """
-Recur(m::RNNCell) = Recur(m, m.state)
+Recur(m::RNNCell) = Recur(m, m.state0)
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 
+# TODO remove in v0.13
+function Base.getproperty(m::RNNCell, sym::Symbol)
+  if sym === :h
+    @warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead."
+    return getfield(m, :state0)
+  else
+    return getfield(m, sym)
+  end
+end
+
 # LSTM
 
 struct LSTMCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  state::S
+  state0::S
 end
 
 function LSTMCell(in::Integer, out::Integer;
@@ -138,16 +157,29 @@ for a good overview of the internals.
 """
 # Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
 # (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
-Recur(m::LSTMCell) = Recur(m, m.state)
+Recur(m::LSTMCell) = Recur(m, m.state0)
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
 
+# TODO remove in v0.13
+function Base.getproperty(m::LSTMCell, sym::Symbol)
+  if sym === :h
+    @warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead."
+    return getfield(m, :state0)[1]
+  elseif sym === :c
+    @warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
+    return getfield(m, :state0)[2]
+  else
+    return getfield(m, sym)
+  end
+end
+
 # GRU
 
 struct GRUCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  state::S
+  state0::S
 end
 
 GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
@@ -159,7 +191,7 @@ function (m::GRUCell)(h, x)
   r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
   z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
   h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
-  h′ = (1 .- z).*h̃ .+ z.*h
+  h′ = (1 .- z) .* h̃ .+ z .* h
   return h′, h′
 end
 
@@ -177,9 +209,19 @@ RNN but generally exhibits a longer memory span over sequences.
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
-Recur(m::GRUCell) = Recur(m, m.state)
+Recur(m::GRUCell) = Recur(m, m.state0)
 GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
 
+# TODO remove in v0.13
+function Base.getproperty(m::GRUCell, sym::Symbol)
+  if sym === :h
+    @warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."
+    return getfield(m, :state0)
+  else
+    return getfield(m, sym)
+  end
+end
+
 @adjoint function Broadcast.broadcasted(f::Recur, args...)
   Zygote.∇map(__context__, f, args...)
 end
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 0691f5929e..f69d12820b 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,6 +1,6 @@
-# Ref FluxML/Flux.jl#1209
-@testset "BPTT" begin
-  seq = [rand(Float32, (2,1)) for i = 1:3]
+# Ref FluxML/Flux.jl#1209 1D input
+@testset "BPTT-1D" begin
+  seq = [rand(Float32, 2) for i = 1:3]
   for r ∈ [RNN,]
     rnn = r(2,3)
     Flux.reset!(rnn)
@@ -11,7 +11,7 @@
     bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                   tanh.(rnn.cell.Wi * seq[2] + Wh *
                                         tanh.(rnn.cell.Wi * seq[1] +
-                                              Wh * rnn.cell.state
+                                              Wh * rnn.cell.state0
                                               + rnn.cell.b)
                                         + rnn.cell.b)
                                   + rnn.cell.b)),
@@ -19,3 +19,25 @@
     @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
   end
 end
+
+# Ref FluxML/Flux.jl#1209 2D input
+@testset "BPTT-2D" begin
+  seq = [rand(Float32, (2,1)) for i = 1:3]
+  for r ∈ [RNN,]
+    rnn = r(2,3)
+    Flux.reset!(rnn)
+    grads_seq = gradient(Flux.params(rnn)) do
+      sum(rnn.(seq)[3])
+    end
+    Flux.reset!(rnn);
+    bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
+                                  tanh.(rnn.cell.Wi * seq[2] + Wh *
+                                        tanh.(rnn.cell.Wi * seq[1] +
+                                              Wh * rnn.cell.state0
+                                              + rnn.cell.b)
+                                        + rnn.cell.b)
+                                  + rnn.cell.b)),
+                    rnn.cell.Wh)
+    @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
+  end
+end