RNN update to drop CUDNN, fix LSTM bug and output type stability #1367

Merged · 23 commits · Nov 7, 2020
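For orientation, a minimal sketch of the stateful recurrent API this PR lands on (the layer sizes and input data below are illustrative, not taken from the PR):

```julia
using Flux

# Recur wraps a cell and carries the hidden state between calls;
# after this PR the cell itself stores the initial state it is reset to.
rnn = RNN(2, 3)                            # Recur(RNNCell(2, 3, tanh))

seq = [rand(Float32, 2, 1) for _ in 1:5]   # 5 time steps, one column per step
ys  = map(rnn, seq)                        # each call advances rnn.state

Flux.reset!(rnn)                           # restores rnn.state from rnn.cell.state
```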
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,4 +5,5 @@
 docs/build/
 docs/site/
 deps
-# Manifest.toml
+.vscode
+# Manifest.toml
7 changes: 5 additions & 2 deletions src/cuda/cuda.jl
@@ -1,9 +1,12 @@
 module CUDAint
 
 using ..CUDA
-
 using CUDA: CUDNN
-include("curnn.jl")
+
+import ..Flux: Flux
+import Zygote
+using Zygote: @adjoint
+
 include("cudnn.jl")
 
 end
89 changes: 0 additions & 89 deletions src/cuda/curnn.jl

This file was deleted.

55 changes: 25 additions & 30 deletions src/layers/recurrent.jl
@@ -1,3 +1,4 @@
+
 gate(h, n) = (1:h) .+ h*(n-1)
 gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
 gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@@ -24,21 +25,19 @@ rnn.(1:10) # apply to a sequence
 rnn.state # 60
 ```
 """
-mutable struct Recur{T}
+mutable struct Recur{T,S}
   cell::T
-  init
-  state
+  state::S
 end
 
-Recur(m, h = hidden(m)) = Recur(m, h, h)
-
 function (m::Recur)(xs...)
   h, y = m.cell(m.state, xs...)
   m.state = h
   return y
 end
 
-@functor Recur cell, init
+@functor Recur
+trainable(a::Recur) = (a.cell,)
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 
@@ -52,34 +51,30 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
 rnn.state = hidden(rnn.cell)
 ```
 """
-reset!(m::Recur) = (m.state = m.init)
+reset!(m::Recur) = (m.state = m.cell.state)
 reset!(m) = foreach(reset!, functor(m)[1])
 
 flip(f, xs) = reverse(f.(reverse(xs)))
 
 # Vanilla RNN
 
-mutable struct RNNCell{F,A,V}
+struct RNNCell{F,A,V,S}
   σ::F
   Wi::A
   Wh::A
   b::V
-  h::V
+  state::S
 end
 
-RNNCell(in::Integer, out::Integer, σ = tanh;
-        init = glorot_uniform) =
-  RNNCell(σ, init(out, in), init(out, out),
-          init(out), zeros(out))
+RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
+  RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
 
 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
   h = σ.(Wi*x .+ Wh*h .+ b)
   return h, h
 end
 
-hidden(m::RNNCell) = m.h
-
 @functor RNNCell
 
 function Base.show(io::IO, l::RNNCell)
@@ -94,22 +89,23 @@ end
 The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
 """
+Recur(m::RNNCell) = Recur(m, m.state)
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 
 # LSTM
 
-mutable struct LSTMCell{A,V}
+struct LSTMCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  h::V
-  c::V
+  state::S
 end
 
 function LSTMCell(in::Integer, out::Integer;
-                  init = glorot_uniform)
-  cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
-                  zeros(out), zeros(out))
+                  init = glorot_uniform,
+                  initb = zeros,
+                  init_state = zeros)
+  cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1)))
   cell.b[gate(out, 2)] .= 1
   return cell
 end
@@ -126,8 +122,6 @@ function (m::LSTMCell)((h, c), x)
   return (h′, c), h′
 end
 
-hidden(m::LSTMCell) = (m.h, m.c)
-
 @functor LSTMCell
 
 Base.show(io::IO, l::LSTMCell) =
@@ -142,20 +136,22 @@ recurrent layer. Behaves like an RNN but generally exhibits a longer memory span
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
+# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
+#   (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
+Recur(m::LSTMCell) = Recur(m, m.state)
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
 
 # GRU
 
-mutable struct GRUCell{A,V}
+struct GRUCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  h::V
+  state::S
 end
 
-GRUCell(in, out; init = glorot_uniform) =
-  GRUCell(init(out * 3, in), init(out * 3, out),
-          init(out * 3), zeros(out))
+GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
+  GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
 
 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
@@ -167,8 +163,6 @@ function (m::GRUCell)(h, x)
   return h′, h′
 end
 
-hidden(m::GRUCell) = m.h
-
 @functor GRUCell
 
 Base.show(io::IO, l::GRUCell) =
@@ -183,6 +177,7 @@ RNN but generally exhibits a longer memory span over sequences.
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
+Recur(m::GRUCell) = Recur(m, m.state)
 GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
 
 @adjoint function Broadcast.broadcasted(f::Recur, args...)
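The cell constructors above replace the old `h`/`c` fields with a `state` field and gain `initb` and `init_state` keywords next to `init`. A short sketch of how they might be used; only the keyword names and the `Recur(m::RNNCell) = Recur(m, m.state)` convenience come from the diff, the custom initializers are hypothetical:

```julia
using Flux

# Defaults: glorot_uniform weights, zero bias, zero initial state stored as an
# (out, 1) column inside the cell (a tuple of two such columns for the LSTM).
lstm = LSTM(4, 8)

# Hypothetical custom initializers passed through the new keywords.
cell = Flux.RNNCell(4, 8;
                    init       = Flux.glorot_uniform,
                    initb      = out -> fill(0.01f0, out),
                    init_state = (dims...) -> zeros(Float32, dims...))
rnn = Flux.Recur(cell)   # uses the new Recur(m::RNNCell) = Recur(m, m.state)
```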
8 changes: 3 additions & 5 deletions test/cuda/curnn.jl
@@ -8,7 +8,7 @@ using Flux: pullback
   Flux.reset!(m)
   θ = gradient(() -> sum(m(x)), params(m))
   @test x isa CuArray
-  @test_broken θ[m.cell.Wi] isa CuArray
+  @test θ[m.cell.Wi] isa CuArray
   @test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
 end

@@ -20,17 +20,15 @@ end
     Flux.reset!(rnn)
     Flux.reset!(curnn)
     x = batch_size == 1 ?
-        rand(10) :
-        rand(10, batch_size)
+        rand(Float32, 10) :
+        rand(Float32, 10, batch_size)
     cux = gpu(x)
 
     y, back = pullback((r, x) -> r(x), rnn, x)
     cuy, cuback = pullback((r, x) -> r(x), curnn, cux)
 
     @test y ≈ collect(cuy)
 
-    @test haskey(Flux.CUDAint.descs, curnn.cell)
-
     ȳ = randn(size(y))
     m̄, x̄ = back(ȳ)
     cum̄, cux̄ = cuback(gpu(ȳ))
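With the CUDNN-specific `curnn.jl` wrapper gone, the recurrent layers run on the GPU through the generic code path, which is why the `@test_broken` above becomes a plain `@test`. A hedged sketch of what the updated test asserts (requires a CUDA-capable device; sizes are illustrative):

```julia
using Flux, CUDA

m = gpu(RNN(10, 5))
x = gpu(rand(Float32, 10, 16))          # 16-sample batch

Flux.reset!(m)
θ = gradient(() -> sum(m(x)), params(m))

θ[m.cell.Wi] isa CuArray                # expected to hold after this PR
```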
4 changes: 2 additions & 2 deletions test/layers/recurrent.jl
@@ -1,6 +1,6 @@
 # Ref FluxML/Flux.jl#1209
 @testset "BPTT" begin
-  seq = [rand(2) for i = 1:3]
+  seq = [rand(Float32, (2,1)) for i = 1:3]
   for r ∈ [RNN,]
     rnn = r(2,3)
     Flux.reset!(rnn)
@@ -11,7 +11,7 @@
   bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                             tanh.(rnn.cell.Wi * seq[2] + Wh *
                               tanh.(rnn.cell.Wi * seq[1] +
-                                Wh * rnn.init
+                                Wh * rnn.cell.state
                                 + rnn.cell.b)
                               + rnn.cell.b)
                             + rnn.cell.b)),
6 changes: 3 additions & 3 deletions test/utils.jl
@@ -101,16 +101,16 @@ end
   m = Dense(10, 5)
   @test size.(params(m)) == [(5, 10), (5,)]
   m = RNN(10, 5)
-  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]
 
   # Layer duplicated in same chain, params just once pls.
   c = Chain(m, m)
-  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]
 
   # Self-referential array. Just want params, no stack overflow pls.
   r = Any[nothing,m]
   r[1] = r
-  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
 end
 
 @testset "Basic Stacking" begin
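The expected shapes change because the cell's initial state is now an `(out, 1)` matrix picked up by `params` through `@functor`, instead of a length-`out` vector. A quick sketch of the new shapes:

```julia
using Flux

m = RNN(10, 5)
size.(params(m))   # [(5, 10), (5, 5), (5,), (5, 1)]: Wi, Wh, b, cell state
```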