Skip to content

Commit

Permalink
Merge #1390
Browse files Browse the repository at this point in the history
1390: RNN deprecations and naming fixes r=CarloLucibello a=jeremiedb

Following discussion in #1367 

This PR disambiguates the initial state parameters, now named `state0` in the RNN cells, from the running state of the RNN chain, which remains named `state` in the `Recur` struct.

Add `getproperty` methods with deprecation messages to access the legacy `h` and `c` fields in the RNN cells, as well as the `init` field in `Recur` (which now points to `recur.cell.state0`).

Include both 1D and 2D input dimensions in the basic BPTT test.

Co-authored-by: jeremie.db <[email protected]>
  • Loading branch information
bors[bot] and jeremiedb authored Nov 9, 2020
2 parents 09764f8 + 70e4797 commit 8b53033
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 13 deletions.
62 changes: 52 additions & 10 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ mutable struct Recur{T,S}
end

function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
m.state = h
m.state, y = m.cell(m.state, xs...)
return y
end

Expand All @@ -51,9 +50,19 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
rnn.state = hidden(rnn.cell)
```
"""
reset!(m::Recur) = (m.state = m.cell.state)
reset!(m::Recur) = (m.state = m.cell.state0)
# Generic fallback: apply `reset!` to each element of the first component
# returned by `functor(m)`, so the reset propagates through nested layers.
function reset!(m)
  children = functor(m)[1]
  return foreach(reset!, children)
end

# TODO remove in v0.13
# Property access for `Recur` with a deprecation shim for the legacy `init`
# field: `m.init` logs a warning and returns the wrapped cell's `state0`;
# every other symbol is forwarded unchanged to `getfield`.
function Base.getproperty(m::Recur, sym::Symbol)
  # Common case first: any property other than the deprecated alias.
  sym === :init || return getfield(m, sym)

  @warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead."
  # Use `getfield` throughout so the lookup never re-enters `getproperty`.
  cell = getfield(m, :cell)
  return getfield(cell, :state0)
end

# Broadcast `f` over `xs` taken in reverse, then reverse the results again so
# that the i-th output lines up with the i-th element of `xs`.
# NOTE(review): presumably meant for running a recurrent layer backwards over a
# sequence — confirm against callers.
flip(f, xs) = reverse(f.(reverse(xs)))

# Vanilla RNN
Expand All @@ -63,7 +72,7 @@ struct RNNCell{F,A,V,S}
Wi::A
Wh::A
b::V
state::S
state0::S
end

RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
Expand All @@ -89,16 +98,26 @@ end
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
"""
Recur(m::RNNCell) = Recur(m, m.state)
Recur(m::RNNCell) = Recur(m, m.state0)
# Convenience constructor: forwards all positional and keyword arguments to
# `RNNCell` and wraps the resulting cell in a `Recur`.
function RNN(a...; ka...)
  cell = RNNCell(a...; ka...)
  return Recur(cell)
end

# TODO remove in v0.13
# Property access for `RNNCell` with a deprecation shim for the legacy `h`
# field: `m.h` logs a warning and returns `state0`; every other symbol is
# forwarded unchanged to `getfield`.
function Base.getproperty(m::RNNCell, sym::Symbol)
  # Common case first: any property other than the deprecated alias.
  sym === :h || return getfield(m, sym)

  @warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead."
  return getfield(m, :state0)
end

# LSTM

struct LSTMCell{A,V,S}
Wi::A
Wh::A
b::V
state::S
state0::S
end

function LSTMCell(in::Integer, out::Integer;
Expand Down Expand Up @@ -138,16 +157,29 @@ for a good overview of the internals.
"""
# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
# (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
Recur(m::LSTMCell) = Recur(m, m.state)
Recur(m::LSTMCell) = Recur(m, m.state0)
# Convenience constructor: forwards all positional and keyword arguments to
# `LSTMCell` and wraps the resulting cell in a `Recur`.
function LSTM(a...; ka...)
  cell = LSTMCell(a...; ka...)
  return Recur(cell)
end

# TODO remove in v0.13
# Property access for `LSTMCell` with deprecation shims for the legacy `h` and
# `c` fields: `m.h` maps to `state0[1]` and `m.c` to `state0[2]`, each with a
# warning; every other symbol is forwarded unchanged to `getfield`.
function Base.getproperty(m::LSTMCell, sym::Symbol)
  # Common case first: anything that is not one of the deprecated aliases.
  (sym === :h || sym === :c) || return getfield(m, sym)

  if sym === :h
    @warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead."
    return getfield(m, :state0)[1]
  end

  # Only `:c` can reach this point.
  @warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
  return getfield(m, :state0)[2]
end

# GRU

struct GRUCell{A,V,S}
Wi::A
Wh::A
b::V
state::S
state0::S
end

GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
Expand All @@ -159,7 +191,7 @@ function (m::GRUCell)(h, x)
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h′ = (1 .- z).*h̃ .+ z.*h
h′ = (1 .- z) .* h̃ .+ z .* h
return h′, h′
end

Expand All @@ -177,9 +209,19 @@ RNN but generally exhibits a longer memory span over sequences.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
Recur(m::GRUCell) = Recur(m, m.state)
Recur(m::GRUCell) = Recur(m, m.state0)
# Convenience constructor: forwards all positional and keyword arguments to
# `GRUCell` and wraps the resulting cell in a `Recur`.
function GRU(a...; ka...)
  cell = GRUCell(a...; ka...)
  return Recur(cell)
end

# TODO remove in v0.13
# Property access for `GRUCell` with a deprecation shim for the legacy `h`
# field: `m.h` logs a warning and returns `state0`; every other symbol is
# forwarded unchanged to `getfield`.
function Base.getproperty(m::GRUCell, sym::Symbol)
  # Common case first: any property other than the deprecated alias.
  sym === :h || return getfield(m, sym)

  @warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."
  return getfield(m, :state0)
end

# Custom adjoint for broadcasting a `Recur` layer over its arguments
# (e.g. `rnn.(seq)`): the forward/backward pair is delegated to `Zygote.∇map`
# with the current differentiation context.
# NOTE(review): presumably this makes the broadcast differentiate like a
# sequential `map`, which matters because `Recur` mutates its state on every
# call — confirm against Zygote's `∇map` implementation.
@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
28 changes: 25 additions & 3 deletions test/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ref FluxML/Flux.jl#1209
# Ref FluxML/Flux.jl#1209 1D input
@testset "BPTT" begin
seq = [rand(Float32, (2,1)) for i = 1:3]
seq = [rand(Float32, 2) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2,3)
Flux.reset!(rnn)
Expand All @@ -11,11 +11,33 @@
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
tanh.(rnn.cell.Wi * seq[2] + Wh *
tanh.(rnn.cell.Wi * seq[1] +
Wh * rnn.cell.state
Wh * rnn.cell.state0
+ rnn.cell.b)
+ rnn.cell.b)
+ rnn.cell.b)),
rnn.cell.Wh)
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
end
end

# Ref FluxML/Flux.jl#1209 2D input
@testset "BPTT" begin
  # Each time step is a 2×1 matrix (presumably features × batch — confirm).
  seq = [rand(Float32, (2,1)) for i = 1:3]
  for r ∈ [RNN,]
    rnn = r(2,3)
    Flux.reset!(rnn)
    # Gradient obtained by broadcasting the recurrent layer over the sequence
    # and summing the output of the last time step.
    grads_seq = gradient(Flux.params(rnn)) do
      sum(rnn.(seq)[3])
    end
    # Restore the initial state before computing the reference gradient.
    Flux.reset!(rnn)
    # Reference: the three-step recurrence unrolled by hand, starting from the
    # cell's initial state `state0`, differentiated with respect to `Wh`.
    bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                              tanh.(rnn.cell.Wi * seq[2] + Wh *
                                    tanh.(rnn.cell.Wi * seq[1] +
                                        Wh * rnn.cell.state0
                                    + rnn.cell.b)
                              + rnn.cell.b)
                        + rnn.cell.b)),
                    rnn.cell.Wh)
    # Floating-point results: compare approximately, not exactly.
    @test grads_seq[rnn.cell.Wh] ≈ bptt[1]
  end
end

0 comments on commit 8b53033

Please sign in to comment.