diff --git a/.gitignore b/.gitignore
index df70b5309a..b042b95f75 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,5 @@
 docs/build/
 docs/site/
 deps
-# Manifest.toml
\ No newline at end of file
+.vscode
+# Manifest.toml
diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl
index 7be752a1dd..6c8096f978 100644
--- a/src/cuda/cuda.jl
+++ b/src/cuda/cuda.jl
@@ -1,9 +1,12 @@
 module CUDAint
 
 using ..CUDA
-
 using CUDA: CUDNN
-include("curnn.jl")
+
+import ..Flux: Flux
+import Zygote
+using Zygote: @adjoint
+
 include("cudnn.jl")
 
 end
diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl
deleted file mode 100644
index f4f9cb4f97..0000000000
--- a/src/cuda/curnn.jl
+++ /dev/null
@@ -1,89 +0,0 @@
-import ..Flux: Flux, relu
-
-CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
-CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
-CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
-CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
-
-function CUDNN.RNNDesc(m::CuRNNs{T}) where T
-  h, i = length(m.h), size(m.Wi, 2)
-  mode = m isa CuRNN ?
-    (m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
-    m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
-  r = CUDNN.RNNDesc{T}(mode, i, h)
-  return r
-end
-
-const descs = WeakKeyDict()
-
-function desc(rnn)
-  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
-  CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
-  return d
-end
-
-import Zygote
-using Zygote: @adjoint
-
-function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  y, h′ = CUDNN.forward(desc(m), x, h)
-  return h′, y
-end
-
-function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  y, h′ = CUDNN.forward(desc(m), x, h)
-  return h′, y
-end
-
-function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2])
-  return (h′, c′), y
-end
-
-(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-
-unbroadcast(x::AbstractArray, Δ) =
-  size(x) == size(Δ) ? Δ :
-  length(x) == length(Δ) ? trim(x, Δ) :
-    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
-
-coerce_cuda(x::Union{CuArray,Nothing}) = x
-coerce_cuda(x::Tuple) = coerce_cuda.(x)
-
-coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0)
-
-function struct_grad!(cx::Zygote.Context, x, x̄)
-  for f in fieldnames(typeof(x))
-    Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
-  end
-  dx = Zygote.grad_mut(cx, x)
-  dx[] = Zygote.accum(dx[], x̄)
-  return dx
-end
-
-for RNN in (CuRNN, CuGRU)
-  @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-    (y, ho), back = CUDNN.pullback(desc(m), x, h)
-    (ho, y), function (Δ)
-      dho, dy = coerce_cuda(Δ) # Support FillArrays etc.
-      m̄ = back(dy, dho)
-      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing))
-      (dm, unbroadcast(h, m̄.h), m̄.x)
-    end
-  end
-end
-
-@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c)
-  ((ho, co), y), function (Δ)
-    dhc, dy = coerce_cuda(Δ) # Support FillArrays etc.
-    dho, dco = dhc === nothing ? (nothing, nothing) : dhc
-    m̄ = back(dy, dho, dco)
-    dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing))
-    (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x)
-  end
-end
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 4c3cc0a612..70ba5e6988 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,3 +1,4 @@
+
 gate(h, n) = (1:h) .+ h*(n-1)
 gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
 gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@@ -24,21 +25,19 @@ rnn.(1:10) # apply to a sequence
 rnn.state # 60
 ```
 """
-mutable struct Recur{T}
+mutable struct Recur{T,S}
   cell::T
-  init
-  state
+  state::S
 end
 
-Recur(m, h = hidden(m)) = Recur(m, h, h)
-
 function (m::Recur)(xs...)
   h, y = m.cell(m.state, xs...)
   m.state = h
   return y
 end
 
-@functor Recur cell, init
+@functor Recur
+trainable(a::Recur) = (a.cell,)
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 
@@ -52,25 +51,23 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to:
 rnn.state = hidden(rnn.cell)
 ```
 """
-reset!(m::Recur) = (m.state = m.init)
+reset!(m::Recur) = (m.state = m.cell.state)
 reset!(m) = foreach(reset!, functor(m)[1])
 
 flip(f, xs) = reverse(f.(reverse(xs)))
 
 # Vanilla RNN
 
-mutable struct RNNCell{F,A,V}
+struct RNNCell{F,A,V,S}
   σ::F
   Wi::A
   Wh::A
   b::V
-  h::V
+  state::S
 end
 
-RNNCell(in::Integer, out::Integer, σ = tanh;
-        init = glorot_uniform) =
-  RNNCell(σ, init(out, in), init(out, out),
-          init(out), zeros(out))
+RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) =
+  RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))
 
 function (m::RNNCell)(h, x)
   σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
@@ -78,8 +75,6 @@ function (m::RNNCell)(h, x)
   return h, h
 end
 
-hidden(m::RNNCell) = m.h
-
 @functor RNNCell
 
 function Base.show(io::IO, l::RNNCell)
@@ -94,22 +89,23 @@ end
 The most basic recurrent layer; essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
 """
+Recur(m::RNNCell) = Recur(m, m.state)
 RNN(a...; ka...) = Recur(RNNCell(a...; ka...))
 
 # LSTM
 
-mutable struct LSTMCell{A,V}
+struct LSTMCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  h::V
-  c::V
+  state::S
 end
 
 function LSTMCell(in::Integer, out::Integer;
-                  init = glorot_uniform)
-  cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
-                  zeros(out), zeros(out))
+                  init = glorot_uniform,
+                  initb = zeros,
+                  init_state = zeros)
+  cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1)))
   cell.b[gate(out, 2)] .= 1
   return cell
 end
@@ -126,8 +122,6 @@ function (m::LSTMCell)((h, c), x)
   return (h′, c), h′
 end
 
-hidden(m::LSTMCell) = (m.h, m.c)
-
 @functor LSTMCell
 
 Base.show(io::IO, l::LSTMCell) =
@@ -142,20 +136,22 @@ recurrent layer. Behaves like an RNN but generally exhibits a longer memory span
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
+# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)),
+#   (zeros(length(m.b)÷4), zeros(length(m.b)÷4)))
+Recur(m::LSTMCell) = Recur(m, m.state)
 LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...))
 
 # GRU
 
-mutable struct GRUCell{A,V}
+struct GRUCell{A,V,S}
   Wi::A
   Wh::A
   b::V
-  h::V
+  state::S
 end
 
-GRUCell(in, out; init = glorot_uniform) =
-  GRUCell(init(out * 3, in), init(out * 3, out),
-          init(out * 3), zeros(out))
+GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) =
+  GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))
 
 function (m::GRUCell)(h, x)
   b, o = m.b, size(h, 1)
@@ -167,8 +163,6 @@ function (m::GRUCell)(h, x)
   return h′, h′
 end
 
-hidden(m::GRUCell) = m.h
-
 @functor GRUCell
 
 Base.show(io::IO, l::GRUCell) =
@@ -183,6 +177,7 @@ RNN but generally exhibits a longer memory span over sequences.
 See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
 for a good overview of the internals.
 """
+Recur(m::GRUCell) = Recur(m, m.state)
 GRU(a...; ka...) = Recur(GRUCell(a...; ka...))
 
 @adjoint function Broadcast.broadcasted(f::Recur, args...)
diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl
index 345401e0ff..a7e5fa234f 100644
--- a/test/cuda/curnn.jl
+++ b/test/cuda/curnn.jl
@@ -8,7 +8,7 @@ using Flux: pullback
   Flux.reset!(m)
   θ = gradient(() -> sum(m(x)), params(m))
   @test x isa CuArray
-  @test_broken θ[m.cell.Wi] isa CuArray
+  @test θ[m.cell.Wi] isa CuArray
   @test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
 end
 
@@ -20,8 +20,8 @@ end
     Flux.reset!(rnn)
    Flux.reset!(curnn)
     x = batch_size == 1 ?
-      rand(10) :
-      rand(10, batch_size)
+      rand(Float32, 10) :
+      rand(Float32, 10, batch_size)
     cux = gpu(x)
 
     y, back = pullback((r, x) -> r(x), rnn, x)
@@ -29,8 +29,6 @@ end
 
     @test y ≈ collect(cuy)
 
-    @test haskey(Flux.CUDAint.descs, curnn.cell)
-
     ȳ = randn(size(y))
     m̄, x̄ = back(ȳ)
     cum̄, cux̄ = cuback(gpu(ȳ))
diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl
index 2bb093fc96..0691f5929e 100644
--- a/test/layers/recurrent.jl
+++ b/test/layers/recurrent.jl
@@ -1,6 +1,6 @@
 # Ref FluxML/Flux.jl#1209
 @testset "BPTT" begin
-  seq = [rand(2) for i = 1:3]
+  seq = [rand(Float32, (2,1)) for i = 1:3]
   for r ∈ [RNN,]
     rnn = r(2,3)
     Flux.reset!(rnn)
@@ -11,7 +11,7 @@
     bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] +
                             Wh * tanh.(rnn.cell.Wi * seq[2] +
                             Wh * tanh.(rnn.cell.Wi * seq[1] +
-                            Wh * rnn.init
+                            Wh * rnn.cell.state
                             + rnn.cell.b)
                             + rnn.cell.b)
                             + rnn.cell.b)),
diff --git a/test/utils.jl b/test/utils.jl
index b2deed4adb..cb5c12cc2f 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -101,16 +101,16 @@ end
   m = Dense(10, 5)
   @test size.(params(m)) == [(5, 10), (5,)]
   m = RNN(10, 5)
-  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)]
 
   # Layer duplicated in same chain, params just once pls.
   c = Chain(m, m)
-  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)]
 
   # Self-referential array. Just want params, no stack overflow pls.
   r = Any[nothing,m]
   r[1] = r
-  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
+  @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)]
 end
 
 @testset "Basic Stacking" begin
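The patch drops the CUDNN-specific recurrent kernels and moves the initial hidden state onto each cell (`cell.state`), with `Recur(cell)` seeding its running state from it and `reset!` restoring it. Below is a minimal usage sketch of that refactored API; it assumes the constructors behave exactly as shown in the hunks above, and the input sizes and `Float32` element type are illustrative only.

```julia
using Flux

# RNN(in, out) wraps an RNNCell whose initial state is a (out, 1) column
# stored on the cell; Recur(cell) starts its running state from cell.state.
rnn = RNN(2, 3)

x = rand(Float32, 2, 5)    # feature × batch input (illustrative sizes)
y = rnn(x)                 # (3, 5) output; rnn.state now holds the batched hidden state

Flux.reset!(rnn)           # rnn.state = rnn.cell.state, i.e. back to the stored zeros(3, 1)

# State is carried implicitly between calls when unrolling over a sequence.
seq = [rand(Float32, 2, 5) for _ in 1:10]
ys  = map(rnn, seq)

# The keyword initializers introduced by the patch can be overridden per cell.
cell  = Flux.RNNCell(2, 3; init = Flux.glorot_uniform, initb = zeros, init_state = zeros)
layer = Flux.Recur(cell)   # uses the new Recur(m::RNNCell) = Recur(m, m.state) convenience
```

This mirrors the updated tests: `params(RNN(10, 5))` now reports the `(5, 1)` state array alongside `Wi`, `Wh`, and `b`, and the BPTT test references `rnn.cell.state` instead of the removed `rnn.init`.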