Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RNN deprecations and naming fixes #1390

Merged
merged 4 commits into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ mutable struct Recur{T,S}
end

function (m::Recur)(xs...)
h, y = m.cell(m.state, xs...)
m.state = h
m.state, y = m.cell(m.state, xs...)
return y
end

Expand All @@ -51,9 +50,19 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
rnn.state = hidden(rnn.cell)
```
"""
reset!(m::Recur) = (m.state = m.cell.state)
reset!(m::Recur) = (m.state = m.cell.state0)
reset!(m) = foreach(reset!, functor(m)[1])

# TODO remove in v0.13
function Base.getproperty(m::Recur, sym::Symbol)
if sym === :init
@warn "Recur field :init has been deprecated. To access initial state weights, use m::Recur.cell.state0 instead."
return getfield(m.cell, :state0)
else
return getfield(m, sym)
end
end

flip(f, xs) = reverse(f.(reverse(xs)))

# Vanilla RNN
Expand All @@ -63,7 +72,7 @@ struct RNNCell{F,A,V,S}
Wi::A
Wh::A
b::V
state::S
state0::S
end

RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
Expand All @@ -89,16 +98,26 @@ end
The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
output fed back into the input each time step.
"""
Recur(m::RNNCell) = Recur(m, m.state)
Recur(m::RNNCell) = Recur(m, m.state0)
RNN(a...; ka...) = Recur(RNNCell(a...; ka...))

# TODO remove in v0.13
function Base.getproperty(m::RNNCell, sym::Symbol)
if sym === :h
@warn "RNNCell field :h has been deprecated. Use m::RNNCell.state0 instead."
return getfield(m, :state0)
else
return getfield(m, sym)
end
end

# LSTM

struct LSTMCell{A,V,S}
Wi::A
Wh::A
b::V
state::S
state0::S
end

function LSTMCell(in::Integer, out::Integer;
Expand Down Expand Up @@ -138,16 +157,29 @@ for a good overview of the internals.
"""
# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
# (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
Recur(m::LSTMCell) = Recur(m, m.state)
Recur(m::LSTMCell) = Recur(m, m.state0)
LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))

# TODO remove in v0.13
function Base.getproperty(m::LSTMCell, sym::Symbol)
if sym === :h
@warn "LSTMCell field :h has been deprecated. Use m::LSTMCell.state0[1] instead."
return getfield(m, :state0)[1]
elseif sym === :c
@warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
return getfield(m, :state0)[2]
else
return getfield(m, sym)
end
end

# GRU

struct GRUCell{A,V,S}
Wi::A
Wh::A
b::V
state::S
state0::S
end

GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
Expand All @@ -159,7 +191,7 @@ function (m::GRUCell)(h, x)
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h′ = (1 .- z).*h̃ .+ z.*h
h′ = (1 .- z) .* h̃ .+ z .* h
return h′, h′
end

Expand All @@ -177,9 +209,19 @@ RNN but generally exhibits a longer memory span over sequences.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
Recur(m::GRUCell) = Recur(m, m.state)
Recur(m::GRUCell) = Recur(m, m.state0)
GRU(a...; ka...) = Recur(GRUCell(a...; ka...))

# TODO remove in v0.13
function Base.getproperty(m::GRUCell, sym::Symbol)
if sym === :h
@warn "GRUCell field :h has been deprecated. Use m::GRUCell.state0 instead."
return getfield(m, :state0)
else
return getfield(m, sym)
end
end

@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
28 changes: 25 additions & 3 deletions test/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ref FluxML/Flux.jl#1209
# Ref FluxML/Flux.jl#1209 1D input
@testset "BPTT" begin
seq = [rand(Float32, (2,1)) for i = 1:3]
seq = [rand(Float32, 2) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2,3)
Flux.reset!(rnn)
Expand All @@ -11,11 +11,33 @@
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
tanh.(rnn.cell.Wi * seq[2] + Wh *
tanh.(rnn.cell.Wi * seq[1] +
Wh * rnn.cell.state
Wh * rnn.cell.state0
+ rnn.cell.b)
+ rnn.cell.b)
+ rnn.cell.b)),
rnn.cell.Wh)
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
end
end

# Ref FluxML/Flux.jl#1209 2D input
@testset "BPTT" begin
seq = [rand(Float32, (2,1)) for i = 1:3]
for r ∈ [RNN,]
rnn = r(2,3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn.(seq)[3])
end
Flux.reset!(rnn);
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
tanh.(rnn.cell.Wi * seq[2] + Wh *
tanh.(rnn.cell.Wi * seq[1] +
Wh * rnn.cell.state0
+ rnn.cell.b)
+ rnn.cell.b)
+ rnn.cell.b)),
rnn.cell.Wh)
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
end
end