From 7d9de60e859c6be2170d419b190dd0eb3546f20d Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Sat, 7 Nov 2020 17:48:45 -0500 Subject: [PATCH 1/3] RNN getproperty deprecation messages rename initial state weights to state0 add 1d & 2D input RNN BPTT gradient test --- src/layers/recurrent.jl | 52 ++++++++++++++++++++++++++++++++++------ test/layers/recurrent.jl | 28 +++++++++++++++++++--- 2 files changed, 70 insertions(+), 10 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 70ba5e6988..01249ec36b 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -31,8 +31,7 @@ mutable struct Recur{T,S} end function (m::Recur)(xs...) - h, y = m.cell(m.state, xs...) - m.state = h + m.state, y = m.cell(m.state, xs...) return y end @@ -51,9 +50,18 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: rnn.state = hidden(rnn.cell) ``` """ -reset!(m::Recur) = (m.state = m.cell.state) +reset!(m::Recur) = (m.state = m.cell.state0) reset!(m) = foreach(reset!, functor(m)[1]) +function Base.getproperty(m::Recur, sym::Symbol) + if sym === :init + @warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead." + return getfield(m.cell, sym) + else + return getfield(m, sym) + end +end + flip(f, xs) = reverse(f.(reverse(xs))) # Vanilla RNN @@ -63,7 +71,7 @@ struct RNNCell{F,A,V,S} Wi::A Wh::A b::V - state::S + state0::S end RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) = @@ -92,13 +100,22 @@ output fed back into the input each time step. Recur(m::RNNCell) = Recur(m, m.state) RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) +function Base.getproperty(m::RNNCell, sym::Symbol) + if sym === :h + @warn "RNNCell field :h has been deprecated for m::RNNCell.state0." + return getfield(m, :state0) + else + return getfield(m, sym) + end +end + # LSTM struct LSTMCell{A,V,S} Wi::A Wh::A b::V - state::S + state0::S end function LSTMCell(in::Integer, out::Integer; @@ -141,13 +158,25 @@ for a good overview of the internals. Recur(m::LSTMCell) = Recur(m, m.state) LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) +function Base.getproperty(m::LSTMCell, sym::Symbol) + if sym === :h + @warn "LSTMCell field :h has been deprecated for m::LSTMCell.state0[1]." + return getfield(m, :state0)[1] + elseif sym === :c + @warn "LSTMCell field :c has been deprecated for m::LSTMCell.state0[2]." + return getfield(m, :state0)[2] + else + return getfield(m, sym) + end +end + # GRU struct GRUCell{A,V,S} Wi::A Wh::A b::V - state::S + state0::S end GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) = @@ -159,7 +188,7 @@ function (m::GRUCell)(h, x) r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1)) z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2)) h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3)) - h′ = (1 .- z).*h̃ .+ z.*h + h′ = (1 .- z) .* h̃ .+ z .* h return h′, h′ end @@ -180,6 +209,15 @@ for a good overview of the internals. Recur(m::GRUCell) = Recur(m, m.state) GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) +function Base.getproperty(m::GRUCell, sym::Symbol) + if sym === :h + @warn "GRUCell field :h has been deprecated for m::GRUCell.state0." + return getfield(m, :state0) + else + return getfield(m, sym) + end +end + @adjoint function Broadcast.broadcasted(f::Recur, args...) Zygote.∇map(__context__, f, args...) end diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 0691f5929e..f69d12820b 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -1,6 +1,6 @@ -# Ref FluxML/Flux.jl#1209 +# Ref FluxML/Flux.jl#1209 1D input @testset "BPTT" begin - seq = [rand(Float32, (2,1)) for i = 1:3] + seq = [rand(Float32, 2) for i = 1:3] for r ∈ [RNN,] rnn = r(2,3) Flux.reset!(rnn) @@ -11,7 +11,7 @@ bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh * tanh.(rnn.cell.Wi * seq[2] + Wh * tanh.(rnn.cell.Wi * seq[1] + - Wh * rnn.cell.state + Wh * rnn.cell.state0 + rnn.cell.b) + rnn.cell.b) + rnn.cell.b)), @@ -19,3 +19,25 @@ @test grads_seq[rnn.cell.Wh] ≈ bptt[1] end end + +# Ref FluxML/Flux.jl#1209 2D input +@testset "BPTT" begin + seq = [rand(Float32, (2,1)) for i = 1:3] + for r ∈ [RNN,] + rnn = r(2,3) + Flux.reset!(rnn) + grads_seq = gradient(Flux.params(rnn)) do + sum(rnn.(seq)[3]) + end + Flux.reset!(rnn); + bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh * + tanh.(rnn.cell.Wi * seq[2] + Wh * + tanh.(rnn.cell.Wi * seq[1] + + Wh * rnn.cell.state0 + + rnn.cell.b) + + rnn.cell.b) + + rnn.cell.b)), + rnn.cell.Wh) + @test grads_seq[rnn.cell.Wh] ≈ bptt[1] + end +end \ No newline at end of file From ba754b789a9f4f71ca0a2f2152509d1cc36a39c0 Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Sat, 7 Nov 2020 18:14:49 -0500 Subject: [PATCH 2/3] RNN fix deprecations messages --- src/layers/recurrent.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 01249ec36b..d1eee92eec 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -56,7 +56,7 @@ reset!(m) = foreach(reset!, functor(m)[1]) function Base.getproperty(m::Recur, sym::Symbol) if sym === :init @warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead." - return getfield(m.cell, sym) + return getfield(m.cell, :state0) else return getfield(m, sym) end @@ -97,12 +97,12 @@ end The most basic recurrent layer; essentially acts as a `Dense` layer, but with the output fed back into the input each time step. """ -Recur(m::RNNCell) = Recur(m, m.state) +Recur(m::RNNCell) = Recur(m, m.state0) RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) function Base.getproperty(m::RNNCell, sym::Symbol) if sym === :h - @warn "RNNCell field :h has been deprecated for m::RNNCell.state0." + @warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead." return getfield(m, :state0) else return getfield(m, sym) @@ -155,15 +155,15 @@ for a good overview of the internals. """ # Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)), # (zeros(length(m.b)÷4), zeros(length(m.b)÷4))) -Recur(m::LSTMCell) = Recur(m, m.state) +Recur(m::LSTMCell) = Recur(m, m.state0) LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) function Base.getproperty(m::LSTMCell, sym::Symbol) if sym === :h - @warn "LSTMCell field :h has been deprecated for m::LSTMCell.state0[1]." + @warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead." return getfield(m, :state0)[1] elseif sym === :c - @warn "LSTMCell field :c has been deprecated for m::LSTMCell.state0[2]." + @warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead." return getfield(m, :state0)[2] else return getfield(m, sym) @@ -206,12 +206,12 @@ RNN but generally exhibits a longer memory span over sequences. See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals. """ -Recur(m::GRUCell) = Recur(m, m.state) +Recur(m::GRUCell) = Recur(m, m.state0) GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) function Base.getproperty(m::GRUCell, sym::Symbol) if sym === :h - @warn "GRUCell field :h has been deprecated for m::GRUCell.state0." + @warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead." return getfield(m, :state0) else return getfield(m, sym) From 70e4797ac2f85c4d6b3eb01626d417352eecec94 Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Sun, 8 Nov 2020 19:07:20 -0500 Subject: [PATCH 3/3] deprecation msg for RNN getproperty --- src/layers/recurrent.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index d1eee92eec..f56d47cbd7 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -53,6 +53,7 @@ rnn.state = hidden(rnn.cell) reset!(m::Recur) = (m.state = m.cell.state0) reset!(m) = foreach(reset!, functor(m)[1]) +# TODO remove in v0.13 function Base.getproperty(m::Recur, sym::Symbol) if sym === :init @warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead." @@ -100,6 +101,7 @@ output fed back into the input each time step. Recur(m::RNNCell) = Recur(m, m.state0) RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) +# TODO remove in v0.13 function Base.getproperty(m::RNNCell, sym::Symbol) if sym === :h @warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead." @@ -158,6 +160,7 @@ for a good overview of the internals. Recur(m::LSTMCell) = Recur(m, m.state0) LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) +# TODO remove in v0.13 function Base.getproperty(m::LSTMCell, sym::Symbol) if sym === :h @warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead." @@ -209,6 +212,7 @@ for a good overview of the internals. Recur(m::GRUCell) = Recur(m, m.state0) GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) +# TODO remove in v0.13 function Base.getproperty(m::GRUCell, sym::Symbol) if sym === :h @warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."