From 454bbf6b1342a81dde5284320584ce5bdf2c628a Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Mon, 19 Oct 2020 12:18:58 -0400 Subject: [PATCH 01/17] test RNN --- Manifest.toml | 157 ++++++++++++++++------------- src/Flux.jl | 5 +- src/cuda/cuda.jl | 3 +- src/cuda/curnn_jdb_v1.jl | 98 ++++++++++++++++++ src/layers/recurrent_jdb.jl | 191 ++++++++++++++++++++++++++++++++++++ test/rnn-test-jdb.jl | 79 +++++++++++++++ 6 files changed, 462 insertions(+), 71 deletions(-) create mode 100644 src/cuda/curnn_jdb_v1.jl create mode 100644 src/layers/recurrent_jdb.jl create mode 100644 test/rnn-test-jdb.jl diff --git a/Manifest.toml b/Manifest.toml index 81fb59f7d9..c184b96a11 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -14,15 +14,21 @@ version = "0.3.3" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "0fac443759fa829ed8066db6cf1077d888bb6573" +git-tree-sha1 = "42c42f2221906892ceb765dbcb1a51deeffd86d7" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "2.0.2" +version = "2.3.0" [[ArrayLayouts]] -deps = ["FillArrays", "LinearAlgebra"] -git-tree-sha1 = "951c3fc1ff93497c88fb1dfa893f4de55d0b38e3" +deps = ["Compat", "FillArrays", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "613ee839b41d4fb5aef6494a9ac998d058729051" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "0.3.8" +version = "0.4.9" + +[[Artifacts]] +deps = ["Pkg"] +git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744" +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.3.0" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -40,21 +46,21 @@ version = "0.4.1" [[CUDA]] deps = ["AbstractFFTs", "Adapt", "BinaryProvider", "CEnum", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"] -git-tree-sha1 = "2c007bb3931360013d32bcd53a727e0a7f144e00" +git-tree-sha1 = "83bfd180e2f842f6d4ee315a6db8665e9aa0c19b" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "1.2.1" +version = "1.3.3" [[ChainRules]] deps = ["ChainRulesCore", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "f54401294c1ed68d7cc2a7ea280ec101bca722bc" +git-tree-sha1 = "04afa37b7ea5fe1964ebec15b7f4769276b4edc5" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.12" +version = "0.7.26" [[ChainRulesCore]] -deps = ["MuladdMacro"] -git-tree-sha1 = "971b03f25bdf2acab79f1c51afc717f9dccf43c2" +deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] +git-tree-sha1 = "aebbda0a7c644bd8739b34f2a1b1e48f114aab49" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.5" +version = "0.9.17" [[CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -64,15 +70,15 @@ version = "0.7.0" [[ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "607c0ea16cb32af49ea2976f90c0c5acbca37d21" +git-tree-sha1 = "4bffea7ed1a9f0f3d1a131bbcd4b925548d75288" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.10.8" +version = "0.10.9" [[Colors]] deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"] -git-tree-sha1 = "5639e44833cfcf78c6a73fbceb4da75611d312cd" +git-tree-sha1 = "008d6bc68dea6beb6303fdc37188cb557391ebf2" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.3" +version = "0.12.4" [[CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -80,11 +86,17 @@ git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" +[[Compat]] +deps = 
["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "f76e41cf110de7176a657c72409e722cfc86fbb6" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.20.0" + [[CompilerSupportLibraries_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70" uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.3.3+0" +version = "0.3.4+0" [[CpuId]] deps = ["Markdown", "Test"] @@ -98,10 +110,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.3.0" [[DataStructures]] -deps = ["InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "88d48e133e6d3dd68183309877eac74393daa7eb" +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "db07bb22795762895b60e44d62b34b16c982a687" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.17.20" +version = "0.18.7" [[Dates]] deps = ["Printf"] @@ -129,20 +141,20 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "c5714d9bcdba66389612dc4c47ed827c64112997" +git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.2" +version = "0.8.3" [[ExprTools]] -git-tree-sha1 = "6f0517056812fd6aa3af23d4b70d5325a2ae4e95" +git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.1" +version = "0.1.3" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "4863cbb7910079369e258dee4add9d06ead5063a" +git-tree-sha1 = "502b3de6039d5b78c76118423858d981349f3823" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.8.14" +version = "0.9.7" [[FixedPointNumbers]] deps = ["Statistics"] @@ -168,31 +180,36 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[GPUArrays]] deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "9027d6128645d227fa9ba7d5b3aa00af696b9aaf" +git-tree-sha1 = "da6398282abd2a8c0dc3e55b49d984fcc2c582e5" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "5.0.0" +version = "5.2.1" [[GPUCompiler]] deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "d6db068d8373bb45ac17e14744cc1d3681d0dbc7" +git-tree-sha1 = "05097d81898c527e3bf218bb083ad0ead4378e5f" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.5.5" +version = "0.6.1" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "6875ae3cfcb9a50af80553d5cc825f406e8d13bc" +git-tree-sha1 = "a8d88c05a23b44b4da6cf4fb5659e13ff95e0f47" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.0" +version = "0.4.1" [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[JLLWrappers]] +git-tree-sha1 = "7cec881362e5b4e367ff0279dd99a06526d51a55" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.1.2" + [[Juno]] deps = ["Base64", "Logging", "Media", "Profile"] -git-tree-sha1 = "90976c3ab792a98d240d42f9df07420ccfc60668" +git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.8.3" +version = "0.8.4" [[LLVM]] deps = ["CEnum", 
"Libdl", "Printf", "Unicode"] @@ -216,15 +233,15 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[LoopVectorization]] deps = ["DocStringExtensions", "LinearAlgebra", "OffsetArrays", "SIMDPirates", "SLEEFPirates", "UnPack", "VectorizationBase"] -git-tree-sha1 = "5d1d5fb3e9a4c18ae1aa8cb16b0fbdbfeb7db5ed" +git-tree-sha1 = "3242a8f411e19eda9adc49d0b877681975c11375" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.8.22" +version = "0.8.26" [[MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a" +git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.5" +version = "0.5.6" [[Markdown]] deps = ["Base64"] @@ -238,9 +255,9 @@ version = "0.5.0" [[Missings]] deps = ["DataAPI"] -git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5" +git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.3" +version = "0.4.4" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -252,9 +269,9 @@ version = "0.2.2" [[NNlib]] deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "8ec4693a5422f0b064ce324f59351f24aa474893" +git-tree-sha1 = "1ef04283efe283be08e2d0de842f5e5286dd0b7a" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.4" +version = "0.7.5" [[NaNMath]] git-tree-sha1 = "c84c576296d0e2fbb3fc134d3e09086b3ea617cd" @@ -262,20 +279,20 @@ uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "0.3.4" [[OffsetArrays]] -git-tree-sha1 = "2066e16af994955287f2e03ba1d9e890eb43b0dd" +git-tree-sha1 = "a416e2f267e2c8729f25bcaf1ce19d2893faf393" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.1.2" +version = "1.3.1" [[OpenSpecFun_jll]] -deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] -git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+3" +version = "0.5.3+4" [[OrderedCollections]] -git-tree-sha1 = "293b70ac1780f9584c89268a6e2a560d938a7065" +git-tree-sha1 = "16c08bf5dba06609fe45e30860092d6fa41fde7b" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.0" +version = "1.3.1" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -305,18 +322,18 @@ version = "0.2.0" [[Requires]] deps = ["UUIDs"] -git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" +git-tree-sha1 = "28faf1c963ca1dc3ec87f166d92982e3c4a1f66d" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.0.1" +version = "1.1.0" [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" [[SIMDPirates]] deps = ["VectorizationBase"] -git-tree-sha1 = "884df77ee290b45bcf9d2edcdb186afaf1281c39" +git-tree-sha1 = "450d163d3279a1d35e3aad3352a5167ef21b84a4" uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a" -version = "0.8.23" +version = "0.8.25" [[SLEEFPirates]] deps = ["Libdl", "SIMDPirates", "VectorizationBase"] @@ -327,6 +344,10 @@ version = "0.5.5" [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -358,9 +379,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", 
"Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "a6102b1f364befdb05746f386b67c6b7e3262c45" +git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.0" +version = "0.33.2" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] @@ -383,36 +404,36 @@ deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[UnPack]] -git-tree-sha1 = "d4bfa022cd30df012700cf380af2141961bb3bfb" +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.1" +version = "1.0.2" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[VectorizationBase]] deps = ["CpuId", "Libdl", "LinearAlgebra"] -git-tree-sha1 = "09a4d4896aac75199b332eb26a8d0a9e4655e34a" +git-tree-sha1 = "03e2fbb479a1ea350398195b6fbf439bae0f8260" uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -version = "0.12.30" +version = "0.12.33" [[ZipFile]] deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "254975fef2fc526583bb9b7c9420fe66ffe09f2f" +git-tree-sha1 = "c3a5637e27e914a7a445b8d0ad063d701931e9f7" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.9.2" +version = "0.9.3" [[Zlib_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = "d5bba6485811931e4b8958e2d7ca3738273ac468" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "320228915c8debb12cb434c59057290f0834dbf6" uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.11+15" +version = "1.2.11+18" [[Zygote]] deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "0079d92995b0fbcffd5d475d49ec8ca49375c471" +git-tree-sha1 = "b3f792937233abc61d4020a9e89a7f2bcb67333e" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.5.4" +version = "0.5.9" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/src/Flux.jl b/src/Flux.jl index b8160249f1..efb9e57719 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -38,13 +38,14 @@ include("functor.jl") include("layers/stateless.jl") include("layers/basic.jl") include("layers/conv.jl") -include("layers/recurrent.jl") +# include("layers/recurrent.jl") +include("layers/recurrent_jdb.jl") include("layers/normalise.jl") include("data/Data.jl") include("losses/Losses.jl") -using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12 +using .Losses # TODO: stop importing Losses in Flux's namespace in v0.12 include("deprecations.jl") diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 7be752a1dd..4cce4b2c65 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -3,7 +3,8 @@ module CUDAint using ..CUDA using CUDA: CUDNN -include("curnn.jl") +# include("curnn.jl") +include("curnn_jdb_v1.jl") include("cudnn.jl") end diff --git a/src/cuda/curnn_jdb_v1.jl b/src/cuda/curnn_jdb_v1.jl new file mode 100644 index 0000000000..d2aec1195c --- /dev/null +++ b/src/cuda/curnn_jdb_v1.jl @@ -0,0 +1,98 @@ +import ..Flux: Flux, relu + +CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}} +CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} +CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} +CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} + +function CUDNN.RNNDesc(m::CuRNNs{T}) where T + if isa(m, CuRNN) + m.σ == tanh ? 
mode = CUDNN.CUDNN_RNN_TANH : mode = CUDNN.CUDNN_RNN_RELU + h, i = length(m.b), size(m.Wi, 2) + elseif isa(m, CuGRU) + mode = CUDNN.CUDNN_GRU + h, i = length(m.b)÷3, size(m.Wi, 2) + elseif isa(m, CuLSTM) + mode = CUDNN.CUDNN_LSTM + h, i = length(m.b)÷4, size(m.Wi, 2) + println("h: ", h, ", i:", i) + else + error("typeof m ∉ {CuRNN, CuGRU, CuLSTM}") + end + r = CUDNN.RNNDesc{T}(mode, i, h) + return r +end + +const descs = WeakKeyDict() + +function desc(rnn) + d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) + CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b) + return d +end + +import Zygote +using Zygote: @adjoint + +function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} + y, h′ = CUDNN.forward(desc(m), x, h) + return h′, y +end + +function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} + y, h′ = CUDNN.forward(desc(m), x, h) + return h′, y +end + +function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} + y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2]) + return (h′, c′), y +end + +(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) + +trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) + +unbroadcast(x::AbstractArray, Δ) = + size(x) == size(Δ) ? Δ : + length(x) == length(Δ) ? trim(x, Δ) : + trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) + +coerce_cuda(x::Union{CuArray,Nothing}) = x +coerce_cuda(x::Tuple) = coerce_cuda.(x) + +coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0) + +function struct_grad!(cx::Zygote.Context, x, x̄) + for f in fieldnames(typeof(x)) + Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f)) + end + dx = Zygote.grad_mut(cx, x) + dx[] = Zygote.accum(dx[], x̄) + return dx +end + +for RNN in (CuRNN, CuGRU) + @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} + (y, ho), back = CUDNN.pullback(desc(m), x, h) + (ho, y), function (Δ) + dho, dy = coerce_cuda(Δ) # Support FillArrays etc. + m̄ = back(dy, dho) + dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing)) + (dm, unbroadcast(h, m̄.h), m̄.x) + end + end +end + +@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} + (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c) + ((ho, co), y), function (Δ) + dhc, dy = coerce_cuda(Δ) # Support FillArrays etc. + dho, dco = dhc === nothing ? (nothing, nothing) : dhc + m̄ = back(dy, dho, dco) + dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing)) + (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x) + end +end diff --git a/src/layers/recurrent_jdb.jl b/src/layers/recurrent_jdb.jl new file mode 100644 index 0000000000..028821e65f --- /dev/null +++ b/src/layers/recurrent_jdb.jl @@ -0,0 +1,191 @@ + +gate(h, n) = (1:h) .+ h*(n-1) +gate(x::AbstractVector, h, n) = @view x[gate(h,n)] +gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] + +# Stateful recurrence + +""" + Recur(cell) + +`Recur` takes a recurrent cell and makes it stateful, managing the hidden state +in the background. `cell` should be a model of the form: + + h, y = cell(h, x...) 
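+
+where `h` is the hidden state carried between calls and `y` is the cell's
+output for the current input.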
+ +For example, here's a recurrent network that keeps a running total of its inputs: + +```julia +accum(h, x) = (h + x, x) +rnn = Flux.Recur(accum, 0) +rnn(2) # 2 +rnn(3) # 3 +rnn.state # 5 +rnn.(1:10) # apply to a sequence +rnn.state # 60 +``` +""" +mutable struct Recur{T} + cell::T + init + state +end + +Recur(m, h = hidden(m)) = Recur(m, h, h) + +function (m::Recur)(xs...) + m.state, y = m.cell(m.state, xs...) + return y +end + +@functor Recur cell, init +# @functor Recur cell + +Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") + +""" + reset!(rnn) + +Reset the hidden state of a recurrent layer back to its original value. + +Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: +```julia +rnn.state = hidden(rnn.cell) +``` +""" +reset!(m::Recur) = (m.state = m.init) #modif to broadcast value of init into state +reset!(m) = foreach(reset!, functor(m)[1]) + +flip(f, xs) = reverse(f.(reverse(xs))) + +# Vanilla RNN + +mutable struct RNNCell{F,A,V} + σ::F + Wi::A + Wh::A + b::V + # init::V # change from h - remove - init state learnable parameter from Recur +end + +RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = + RNNCell(σ, init(out, in), init(out, out), init(out)) # remove h/init initialization + +function (m::RNNCell)(h, x) + σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b + h = σ.(Wi*x .+ Wh*h .+ b) + return h, h +end + +hidden(m::RNNCell) = m.h + +@functor RNNCell + +function Base.show(io::IO, l::RNNCell) + print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end + +""" + RNN(in::Integer, out::Integer, σ = tanh) + +The most basic recurrent layer; essentially acts as a `Dense` layer, but with the +output fed back into the input each time step. +""" +Recur(m::RNNCell) = Recur(m, zeros(length(m.b)), zeros(length(m.b))) +RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) + +# LSTM + +mutable struct LSTMCell{A,V} + Wi::A + Wh::A + b::V + # h::V + # c::V +end + +function LSTMCell(in::Integer, out::Integer; + init = glorot_uniform) + cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4)) + cell.b[gate(out, 2)] .= 1 + return cell +end + +function (m::LSTMCell)((h, c), x) + b, o = m.b, size(h, 1) + g = m.Wi*x .+ m.Wh*h .+ b + input = σ.(gate(g, o, 1)) + forget = σ.(gate(g, o, 2)) + cell = tanh.(gate(g, o, 3)) + output = σ.(gate(g, o, 4)) + c = forget .* c .+ input .* cell + h′ = output .* tanh.(c) + return (h′, c), h′ +end + +# hidden(m::LSTMCell) = (m.h, m.c) + +@functor LSTMCell + +Base.show(io::IO, l::LSTMCell) = + print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")") + +""" + LSTM(in::Integer, out::Integer) + +[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) +recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. + +See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) +for a good overview of the internals. +""" +Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)), + (zeros(length(m.b)÷4), zeros(length(m.b)÷4))) +LSTM(a...; ka...) 
= Recur(LSTMCell(a...; ka...)) + +# GRU + +mutable struct GRUCell{A,V} + Wi::A + Wh::A + b::V + # h::V +end + +GRUCell(in, out; init = glorot_uniform) = + GRUCell(init(out * 3, in), init(out * 3, out), init(out * 3)) + +function (m::GRUCell)(h, x) + b, o = m.b, size(h, 1) + gx, gh = m.Wi*x, m.Wh*h + r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1)) + z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2)) + h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3)) + h′ = (1 .- z).*h̃ .+ z.*h + return h′, h′ +end + +# hidden(m::GRUCell) = m.h + +@functor GRUCell + +Base.show(io::IO, l::GRUCell) = + print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")") + +""" + GRU(in::Integer, out::Integer) + +[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an +RNN but generally exhibits a longer memory span over sequences. + +See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) +for a good overview of the internals. +""" +Recur(m::GRUCell) = Recur(m, zeros(length(m.b)÷3), zeros(length(m.b)÷3)) +GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) + +@adjoint function Broadcast.broadcasted(f::Recur, args...) + Zygote.∇map(__context__, f, args...) +end diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl new file mode 100644 index 0000000000..a9caa93e67 --- /dev/null +++ b/test/rnn-test-jdb.jl @@ -0,0 +1,79 @@ +using Revise +using Flux +# using CUDA +using Statistics: mean + +###################### +# basic test +###################### +seq = [[1,2] ./ 10 for b in 1:3] +seq = hcat(seq...) +seq = [seq for i = 1:4] + +m = RNN(2, 5) +m.cell.Wi .= [0.1 0] +m.cell.Wh .= [0.2] +m.cell.b .= 1.1 +# m.cell.h .= 1.2 +m.init .= 1.3 # init value stays at init value - rather than taking state value +m.state .= 1.4 + +params(m) +params(m)[1] +m(seq[2]) +@time m.(seq) +@time map(m, seq) + +###################### +# single array +###################### +seq = [[1,2] ./ 10 for b in 1:3] +seq = hcat(seq...) 
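+# `seq` is now a 2×3 matrix (features × batch); the next two lines replicate it
+# over 4 time steps into a 2×3×4 array (features × batch × time) for the
+# mapslices call below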
+seq = [seq for i = 1:4] +seq = cat(seq..., dims=3) + +m = RNN(2, 5) +m.cell.Wi .= [0.1 0] +m.cell.Wh .= [0.1] +m.cell.b .= 0 +# m.cell.h .= 0 +m.init .= 0.0 +m.state .= 0 + +params(m) +@time mapslices(m, seq, dims=(1,2)) +mapslices(size, seq, dims=(1,2)) + + + +###################### +# issue: https://github.com/FluxML/Flux.jl/issues/1114 +###################### +rnn = Chain(LSTM(16, 8), + Dense(8,1, σ), + x -> reshape(x,:)) + +X = [rand(16,10) for i in 1:20] +Y = rand(10,20) ./ 10 + +rnn = rnn |> gpu +X = gpu(X) +Y = gpu(Y) + +θ = Flux.params(rnn) +loss(x,y) = mean((Flux.stack(rnn.(x),2) .- y) .^ 2f0) +opt = ADAM(1e-3) +size(rnn[1].state[1]) +Flux.reset!(rnn) +size(rnn[1].state[1]) +Flux.train!(loss, θ, [(X,Y)], opt) +size(rnn[1].state[1]) +loss(X,Y) + +Flux.stack(rnn.(X),2) +rnn.(X) + +using CUDA + +x1 = LSTM(16,8) +CUDA.CUDNN.RNNDesc(x1) From f93a11471357b372df5477f8f0f9528f48ecadbc Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Tue, 20 Oct 2020 18:45:58 -0400 Subject: [PATCH 02/17] tests --- src/layers/recurrent_jdb.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/recurrent_jdb.jl b/src/layers/recurrent_jdb.jl index 028821e65f..1bb9616ba8 100644 --- a/src/layers/recurrent_jdb.jl +++ b/src/layers/recurrent_jdb.jl @@ -65,7 +65,7 @@ mutable struct RNNCell{F,A,V} Wi::A Wh::A b::V - # init::V # change from h - remove - init state learnable parameter from Recur + # h::V # remove - init state learnable parameter from Recur end RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = From 4ddf19d007ce3a6eba1bc5139ae4695a9ff5a603 Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Tue, 20 Oct 2020 20:27:35 -0400 Subject: [PATCH 03/17] drop CUDNN rnn --- src/Flux.jl | 3 +- src/cuda/cuda.jl | 6 +- src/cuda/curnn.jl | 184 +++++++++++++++++----------------- src/cuda/curnn_jdb_v1.jl | 98 ------------------ src/layers/recurrent.jl | 30 +++--- src/layers/recurrent_jdb.jl | 191 ------------------------------------ test/rnn-test-jdb.jl | 89 +++++++++++++++++ 7 files changed, 201 insertions(+), 400 deletions(-) delete mode 100644 src/cuda/curnn_jdb_v1.jl delete mode 100644 src/layers/recurrent_jdb.jl diff --git a/src/Flux.jl b/src/Flux.jl index efb9e57719..850eb89f1b 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -38,8 +38,7 @@ include("functor.jl") include("layers/stateless.jl") include("layers/basic.jl") include("layers/conv.jl") -# include("layers/recurrent.jl") -include("layers/recurrent_jdb.jl") +include("layers/recurrent.jl") include("layers/normalise.jl") include("data/Data.jl") diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 4cce4b2c65..cbb738e678 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -3,8 +3,12 @@ module CUDAint using ..CUDA using CUDA: CUDNN + +import ..Flux: Flux +import Zygote +using Zygote: @adjoint + # include("curnn.jl") -include("curnn_jdb_v1.jl") include("cudnn.jl") end diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index f4f9cb4f97..0b82b2ed58 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -1,89 +1,95 @@ -import ..Flux: Flux, relu - -CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}} -CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} -CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} -CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} - -function CUDNN.RNNDesc(m::CuRNNs{T}) where T - h, i = length(m.h), size(m.Wi, 2) - mode = m isa CuRNN ? - (m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) : - m isa CuGRU ? 
CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM - r = CUDNN.RNNDesc{T}(mode, i, h) - return r -end - -const descs = WeakKeyDict() - -function desc(rnn) - d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) - CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b) - return d -end - -import Zygote -using Zygote: @adjoint - -function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′ = CUDNN.forward(desc(m), x, h) - return h′, y -end - -function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′ = CUDNN.forward(desc(m), x, h) - return h′, y -end - -function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2]) - return (h′, c′), y -end - -(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) - -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) - -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) - -coerce_cuda(x::Union{CuArray,Nothing}) = x -coerce_cuda(x::Tuple) = coerce_cuda.(x) - -coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0) - -function struct_grad!(cx::Zygote.Context, x, x̄) - for f in fieldnames(typeof(x)) - Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f)) - end - dx = Zygote.grad_mut(cx, x) - dx[] = Zygote.accum(dx[], x̄) - return dx -end - -for RNN in (CuRNN, CuGRU) - @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - (y, ho), back = CUDNN.pullback(desc(m), x, h) - (ho, y), function (Δ) - dho, dy = coerce_cuda(Δ) # Support FillArrays etc. - m̄ = back(dy, dho) - dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing)) - (dm, unbroadcast(h, m̄.h), m̄.x) - end - end -end - -@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} - (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c) - ((ho, co), y), function (Δ) - dhc, dy = coerce_cuda(Δ) # Support FillArrays etc. - dho, dco = dhc === nothing ? (nothing, nothing) : dhc - m̄ = back(dy, dho, dco) - dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing)) - (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x) - end -end +# import ..Flux: relu +# +# CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}} +# CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} +# CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} +# CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} +# +# function CUDNN.RNNDesc(m::CuRNNs{T}) where T +# if isa(m, CuRNN) +# m.σ == tanh ? mode = CUDNN.CUDNN_RNN_TANH : mode = CUDNN.CUDNN_RNN_RELU +# h, i = length(m.b), size(m.Wi, 2) +# elseif isa(m, CuGRU) +# mode = CUDNN.CUDNN_GRU +# h, i = length(m.b)÷3, size(m.Wi, 2) +# elseif isa(m, CuLSTM) +# mode = CUDNN.CUDNN_LSTM +# h, i = length(m.b)÷4, size(m.Wi, 2) +# println("h: ", h, ", i:", i) +# else +# error("typeof m ∉ {CuRNN, CuGRU, CuLSTM}") +# end +# r = CUDNN.RNNDesc{T}(mode, i, h) +# return r +# end +# +# const descs = WeakKeyDict() +# +# function desc(rnn) +# d = haskey(descs, rnn) ? 
descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) +# CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b) +# return d +# end +# +# function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} +# y, h′ = CUDNN.forward(desc(m), x, h) +# return h′, y +# end +# +# function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} +# y, h′ = CUDNN.forward(desc(m), x, h) +# return h′, y +# end +# +# function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} +# y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2]) +# return (h′, c′), y +# end +# +# (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +# (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +# (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +# +# trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) +# +# unbroadcast(x::AbstractArray, Δ) = +# size(x) == size(Δ) ? Δ : +# length(x) == length(Δ) ? trim(x, Δ) : +# trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) +# +# coerce_cuda(x::Union{CuArray,Nothing}) = x +# coerce_cuda(x::Tuple) = coerce_cuda.(x) +# +# coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0) +# +# function struct_grad!(cx::Zygote.Context, x, x̄) +# for f in fieldnames(typeof(x)) +# Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f)) +# end +# dx = Zygote.grad_mut(cx, x) +# dx[] = Zygote.accum(dx[], x̄) +# return dx +# end +# +# for RNN in (CuRNN, CuGRU) +# @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} +# (y, ho), back = CUDNN.pullback(desc(m), x, h) +# (ho, y), function (Δ) +# dho, dy = coerce_cuda(Δ) # Support FillArrays etc. +# m̄ = back(dy, dho) +# dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing)) +# (dm, unbroadcast(h, m̄.h), m̄.x) +# end +# end +# end +# +# @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} +# (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c) +# ((ho, co), y), function (Δ) +# dhc, dy = coerce_cuda(Δ) # Support FillArrays etc. +# dho, dco = dhc === nothing ? (nothing, nothing) : dhc +# m̄ = back(dy, dho, dco) +# dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing)) +# (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x) +# end +# end diff --git a/src/cuda/curnn_jdb_v1.jl b/src/cuda/curnn_jdb_v1.jl deleted file mode 100644 index d2aec1195c..0000000000 --- a/src/cuda/curnn_jdb_v1.jl +++ /dev/null @@ -1,98 +0,0 @@ -import ..Flux: Flux, relu - -CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}} -CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} -CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} -CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} - -function CUDNN.RNNDesc(m::CuRNNs{T}) where T - if isa(m, CuRNN) - m.σ == tanh ? 
mode = CUDNN.CUDNN_RNN_TANH : mode = CUDNN.CUDNN_RNN_RELU - h, i = length(m.b), size(m.Wi, 2) - elseif isa(m, CuGRU) - mode = CUDNN.CUDNN_GRU - h, i = length(m.b)÷3, size(m.Wi, 2) - elseif isa(m, CuLSTM) - mode = CUDNN.CUDNN_LSTM - h, i = length(m.b)÷4, size(m.Wi, 2) - println("h: ", h, ", i:", i) - else - error("typeof m ∉ {CuRNN, CuGRU, CuLSTM}") - end - r = CUDNN.RNNDesc{T}(mode, i, h) - return r -end - -const descs = WeakKeyDict() - -function desc(rnn) - d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) - CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b) - return d -end - -import Zygote -using Zygote: @adjoint - -function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′ = CUDNN.forward(desc(m), x, h) - return h′, y -end - -function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′ = CUDNN.forward(desc(m), x, h) - return h′, y -end - -function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} - y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2]) - return (h′, c′), y -end - -(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) - -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) - -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) - -coerce_cuda(x::Union{CuArray,Nothing}) = x -coerce_cuda(x::Tuple) = coerce_cuda.(x) - -coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0) - -function struct_grad!(cx::Zygote.Context, x, x̄) - for f in fieldnames(typeof(x)) - Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f)) - end - dx = Zygote.grad_mut(cx, x) - dx[] = Zygote.accum(dx[], x̄) - return dx -end - -for RNN in (CuRNN, CuGRU) - @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} - (y, ho), back = CUDNN.pullback(desc(m), x, h) - (ho, y), function (Δ) - dho, dy = coerce_cuda(Δ) # Support FillArrays etc. - m̄ = back(dy, dho) - dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing)) - (dm, unbroadcast(h, m̄.h), m̄.x) - end - end -end - -@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} - (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c) - ((ho, co), y), function (Δ) - dhc, dy = coerce_cuda(Δ) # Support FillArrays etc. - dho, dco = dhc === nothing ? (nothing, nothing) : dhc - m̄ = back(dy, dho, dco) - dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing)) - (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x) - end -end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 4c3cc0a612..44a127c282 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,3 +1,4 @@ + gate(h, n) = (1:h) .+ h*(n-1) gate(x::AbstractVector, h, n) = @view x[gate(h,n)] gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] @@ -33,8 +34,7 @@ end Recur(m, h = hidden(m)) = Recur(m, h, h) function (m::Recur)(xs...) - h, y = m.cell(m.state, xs...) - m.state = h + m.state, y = m.cell(m.state, xs...) 
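+  # the cell returns the updated hidden state first, so it is written straight
+  # into m.state instead of the former two-step h / m.state assignment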
return y end @@ -52,7 +52,7 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: rnn.state = hidden(rnn.cell) ``` """ -reset!(m::Recur) = (m.state = m.init) +reset!(m::Recur) = (m.state = m.init) #modif to broadcast value of init into state reset!(m) = foreach(reset!, functor(m)[1]) flip(f, xs) = reverse(f.(reverse(xs))) @@ -64,13 +64,10 @@ mutable struct RNNCell{F,A,V} Wi::A Wh::A b::V - h::V end -RNNCell(in::Integer, out::Integer, σ = tanh; - init = glorot_uniform) = - RNNCell(σ, init(out, in), init(out, out), - init(out), zeros(out)) +RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = + RNNCell(σ, init(out, in), init(out, out), init(out)) # remove h/init initialization function (m::RNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b @@ -94,6 +91,7 @@ end The most basic recurrent layer; essentially acts as a `Dense` layer, but with the output fed back into the input each time step. """ +Recur(m::RNNCell) = Recur(m, zeros(length(m.b)), zeros(length(m.b))) RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) # LSTM @@ -102,14 +100,11 @@ mutable struct LSTMCell{A,V} Wi::A Wh::A b::V - h::V - c::V end function LSTMCell(in::Integer, out::Integer; init = glorot_uniform) - cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4), - zeros(out), zeros(out)) + cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4)) cell.b[gate(out, 2)] .= 1 return cell end @@ -126,8 +121,6 @@ function (m::LSTMCell)((h, c), x) return (h′, c), h′ end -hidden(m::LSTMCell) = (m.h, m.c) - @functor LSTMCell Base.show(io::IO, l::LSTMCell) = @@ -142,6 +135,8 @@ recurrent layer. Behaves like an RNN but generally exhibits a longer memory span See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals. """ +Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)), + (zeros(length(m.b)÷4), zeros(length(m.b)÷4))) LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) # GRU @@ -150,12 +145,10 @@ mutable struct GRUCell{A,V} Wi::A Wh::A b::V - h::V end GRUCell(in, out; init = glorot_uniform) = - GRUCell(init(out * 3, in), init(out * 3, out), - init(out * 3), zeros(out)) + GRUCell(init(out * 3, in), init(out * 3, out), init(out * 3)) function (m::GRUCell)(h, x) b, o = m.b, size(h, 1) @@ -167,8 +160,6 @@ function (m::GRUCell)(h, x) return h′, h′ end -hidden(m::GRUCell) = m.h - @functor GRUCell Base.show(io::IO, l::GRUCell) = @@ -183,6 +174,7 @@ RNN but generally exhibits a longer memory span over sequences. See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals. """ +Recur(m::GRUCell) = Recur(m, zeros(length(m.b)÷3), zeros(length(m.b)÷3)) GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) @adjoint function Broadcast.broadcasted(f::Recur, args...) diff --git a/src/layers/recurrent_jdb.jl b/src/layers/recurrent_jdb.jl deleted file mode 100644 index 1bb9616ba8..0000000000 --- a/src/layers/recurrent_jdb.jl +++ /dev/null @@ -1,191 +0,0 @@ - -gate(h, n) = (1:h) .+ h*(n-1) -gate(x::AbstractVector, h, n) = @view x[gate(h,n)] -gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] - -# Stateful recurrence - -""" - Recur(cell) - -`Recur` takes a recurrent cell and makes it stateful, managing the hidden state -in the background. `cell` should be a model of the form: - - h, y = cell(h, x...) 
- -For example, here's a recurrent network that keeps a running total of its inputs: - -```julia -accum(h, x) = (h + x, x) -rnn = Flux.Recur(accum, 0) -rnn(2) # 2 -rnn(3) # 3 -rnn.state # 5 -rnn.(1:10) # apply to a sequence -rnn.state # 60 -``` -""" -mutable struct Recur{T} - cell::T - init - state -end - -Recur(m, h = hidden(m)) = Recur(m, h, h) - -function (m::Recur)(xs...) - m.state, y = m.cell(m.state, xs...) - return y -end - -@functor Recur cell, init -# @functor Recur cell - -Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") - -""" - reset!(rnn) - -Reset the hidden state of a recurrent layer back to its original value. - -Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: -```julia -rnn.state = hidden(rnn.cell) -``` -""" -reset!(m::Recur) = (m.state = m.init) #modif to broadcast value of init into state -reset!(m) = foreach(reset!, functor(m)[1]) - -flip(f, xs) = reverse(f.(reverse(xs))) - -# Vanilla RNN - -mutable struct RNNCell{F,A,V} - σ::F - Wi::A - Wh::A - b::V - # h::V # remove - init state learnable parameter from Recur -end - -RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = - RNNCell(σ, init(out, in), init(out, out), init(out)) # remove h/init initialization - -function (m::RNNCell)(h, x) - σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b - h = σ.(Wi*x .+ Wh*h .+ b) - return h, h -end - -hidden(m::RNNCell) = m.h - -@functor RNNCell - -function Base.show(io::IO, l::RNNCell) - print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)) - l.σ == identity || print(io, ", ", l.σ) - print(io, ")") -end - -""" - RNN(in::Integer, out::Integer, σ = tanh) - -The most basic recurrent layer; essentially acts as a `Dense` layer, but with the -output fed back into the input each time step. -""" -Recur(m::RNNCell) = Recur(m, zeros(length(m.b)), zeros(length(m.b))) -RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) - -# LSTM - -mutable struct LSTMCell{A,V} - Wi::A - Wh::A - b::V - # h::V - # c::V -end - -function LSTMCell(in::Integer, out::Integer; - init = glorot_uniform) - cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4)) - cell.b[gate(out, 2)] .= 1 - return cell -end - -function (m::LSTMCell)((h, c), x) - b, o = m.b, size(h, 1) - g = m.Wi*x .+ m.Wh*h .+ b - input = σ.(gate(g, o, 1)) - forget = σ.(gate(g, o, 2)) - cell = tanh.(gate(g, o, 3)) - output = σ.(gate(g, o, 4)) - c = forget .* c .+ input .* cell - h′ = output .* tanh.(c) - return (h′, c), h′ -end - -# hidden(m::LSTMCell) = (m.h, m.c) - -@functor LSTMCell - -Base.show(io::IO, l::LSTMCell) = - print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷4, ")") - -""" - LSTM(in::Integer, out::Integer) - -[Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) -recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. - -See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) -for a good overview of the internals. -""" -Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)), - (zeros(length(m.b)÷4), zeros(length(m.b)÷4))) -LSTM(a...; ka...) 
= Recur(LSTMCell(a...; ka...)) - -# GRU - -mutable struct GRUCell{A,V} - Wi::A - Wh::A - b::V - # h::V -end - -GRUCell(in, out; init = glorot_uniform) = - GRUCell(init(out * 3, in), init(out * 3, out), init(out * 3)) - -function (m::GRUCell)(h, x) - b, o = m.b, size(h, 1) - gx, gh = m.Wi*x, m.Wh*h - r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1)) - z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2)) - h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3)) - h′ = (1 .- z).*h̃ .+ z.*h - return h′, h′ -end - -# hidden(m::GRUCell) = m.h - -@functor GRUCell - -Base.show(io::IO, l::GRUCell) = - print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")") - -""" - GRU(in::Integer, out::Integer) - -[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an -RNN but generally exhibits a longer memory span over sequences. - -See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) -for a good overview of the internals. -""" -Recur(m::GRUCell) = Recur(m, zeros(length(m.b)÷3), zeros(length(m.b)÷3)) -GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) - -@adjoint function Broadcast.broadcasted(f::Recur, args...) - Zygote.∇map(__context__, f, args...) -end diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl index a9caa93e67..94035640d7 100644 --- a/test/rnn-test-jdb.jl +++ b/test/rnn-test-jdb.jl @@ -77,3 +77,92 @@ using CUDA x1 = LSTM(16,8) CUDA.CUDNN.RNNDesc(x1) + + +######################## +# rnn test gpu +######################## +feat = 2 +h_size = 11 +seq_len = 4 +batch_size = 3 +rnn = Chain(RNN(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x,:)) + +X = [rand(feat, batch_size) for i in 1:seq_len] +Y = rand(batch_size, seq_len) ./ 10 + +rnn = rnn |> gpu +X = gpu(X) +Y = gpu(Y) + +θ = Flux.params(rnn) +mapreduce(length, +, θ) - h_size -1 # num params in RNN + +function loss(x,y) + l = mean((Flux.stack(map(rnn, x),2) .- y) .^ 2f0) + Flux.reset!(rnn) + return l +end + +opt = ADAM(1e-3) +loss(X,Y) +Flux.reset!(rnn) +Flux.train!(loss, θ, [(X,Y)], opt) +loss(X,Y) +for i in 1:100 + Flux.train!(loss, θ, [(X,Y)], opt) +end +Flux.reset!(rnn) +Flux.train!(loss, θ, [(X,Y)], opt) + +θ[1] +θ[3] +θ[4] + + +######################## +# LSTM test gpu +######################## +feat = 32 +h_size = 64 +seq_len = 10 +batch_size = 32 + +rnn = Chain(RNN(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x,:)) + +X = [rand(feat, batch_size) for i in 1:seq_len] +Y = rand(batch_size, seq_len) ./ 10 + +rnn = rnn |> gpu +X = gpu(X) +Y = gpu(Y) + +θ = Flux.params(rnn) +function loss(x,y) + l = mean((Flux.stack(map(rnn, x),2) .- y) .^ 2f0) + Flux.reset!(rnn) + return l +end + +opt = ADAM(1e-4) +loss(X,Y) +Flux.reset!(rnn) +Flux.train!(loss, θ, [(X,Y)], opt) +loss(X,Y) +for i in 1:100 + Flux.train!(loss, θ, [(X,Y)], opt) +end +Flux.reset!(rnn) +Flux.train!(loss, θ, [(X,Y)], opt) + +θ[1] +θ[2] +θ[3] +θ[4] +θ[5] +θ[6] +θ[7] From c6bb21e5b9ed434c5a0b134563e7c376193b86ce Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Tue, 20 Oct 2020 20:47:37 -0400 Subject: [PATCH 04/17] fix tests --- src/cuda/cuda.jl | 1 - test/rnn-test-jdb.jl | 95 +++++++++++++++++++++++--------------------- 2 files changed, 49 insertions(+), 47 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index cbb738e678..9865ddd500 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -1,7 +1,6 @@ module CUDAint using ..CUDA - using CUDA: CUDNN import ..Flux: Flux diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl index 94035640d7..2cd46633b0 100644 --- 
a/test/rnn-test-jdb.jl +++ b/test/rnn-test-jdb.jl @@ -3,6 +3,55 @@ using Flux # using CUDA using Statistics: mean +######################## +# RNN test gpu +######################## + +# illustrate diverging behavior of GPU execution +feat = 32 +h_size = 64 +seq_len = 10 +batch_size = 32 + +rnn = Chain(RNN(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x,:)) + +X = [rand(feat, batch_size) for i in 1:seq_len] +Y = rand(batch_size, seq_len) ./ 10 + +###################################### +#### do not run for cpu execution #### +rnn = rnn |> gpu +X = gpu(X) +Y = gpu(Y) +###################################### + +θ = Flux.params(rnn) +function loss(x,y) + l = mean((Flux.stack(map(rnn, x),2) .- y) .^ 2f0) + Flux.reset!(rnn) + return l +end + +opt = ADAM(1e-4) +loss(X,Y) +Flux.reset!(rnn) +Flux.train!(loss, θ, [(X,Y)], opt) +loss(X,Y) +for i in 1:100 + Flux.train!(loss, θ, [(X,Y)], opt) +end +# loss should show an NA if run on GPU, but not on CPU after going over 100 iterations +loss(X,Y) +Flux.reset!(rnn) +Flux.train!(loss, θ, [(X,Y)], opt) + +θ[1] +θ[2] +θ[3] +θ[4] + ###################### # basic test ###################### @@ -120,49 +169,3 @@ Flux.train!(loss, θ, [(X,Y)], opt) θ[1] θ[3] θ[4] - - -######################## -# LSTM test gpu -######################## -feat = 32 -h_size = 64 -seq_len = 10 -batch_size = 32 - -rnn = Chain(RNN(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x,:)) - -X = [rand(feat, batch_size) for i in 1:seq_len] -Y = rand(batch_size, seq_len) ./ 10 - -rnn = rnn |> gpu -X = gpu(X) -Y = gpu(Y) - -θ = Flux.params(rnn) -function loss(x,y) - l = mean((Flux.stack(map(rnn, x),2) .- y) .^ 2f0) - Flux.reset!(rnn) - return l -end - -opt = ADAM(1e-4) -loss(X,Y) -Flux.reset!(rnn) -Flux.train!(loss, θ, [(X,Y)], opt) -loss(X,Y) -for i in 1:100 - Flux.train!(loss, θ, [(X,Y)], opt) -end -Flux.reset!(rnn) -Flux.train!(loss, θ, [(X,Y)], opt) - -θ[1] -θ[2] -θ[3] -θ[4] -θ[5] -θ[6] -θ[7] From f33f34d5c9aa95be252799ce202f1bd68469ac09 Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Wed, 21 Oct 2020 02:46:21 -0400 Subject: [PATCH 05/17] cleanup --- src/cuda/cuda.jl | 1 - src/cuda/curnn.jl | 95 -------------------------------------------- test/rnn-test-jdb.jl | 47 ++++++++++++++-------- 3 files changed, 30 insertions(+), 113 deletions(-) delete mode 100644 src/cuda/curnn.jl diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 9865ddd500..6c8096f978 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -7,7 +7,6 @@ import ..Flux: Flux import Zygote using Zygote: @adjoint -# include("curnn.jl") include("cudnn.jl") end diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl deleted file mode 100644 index 0b82b2ed58..0000000000 --- a/src/cuda/curnn.jl +++ /dev/null @@ -1,95 +0,0 @@ -# import ..Flux: relu -# -# CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}} -# CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} -# CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} -# CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} -# -# function CUDNN.RNNDesc(m::CuRNNs{T}) where T -# if isa(m, CuRNN) -# m.σ == tanh ? 
mode = CUDNN.CUDNN_RNN_TANH : mode = CUDNN.CUDNN_RNN_RELU -# h, i = length(m.b), size(m.Wi, 2) -# elseif isa(m, CuGRU) -# mode = CUDNN.CUDNN_GRU -# h, i = length(m.b)÷3, size(m.Wi, 2) -# elseif isa(m, CuLSTM) -# mode = CUDNN.CUDNN_LSTM -# h, i = length(m.b)÷4, size(m.Wi, 2) -# println("h: ", h, ", i:", i) -# else -# error("typeof m ∉ {CuRNN, CuGRU, CuLSTM}") -# end -# r = CUDNN.RNNDesc{T}(mode, i, h) -# return r -# end -# -# const descs = WeakKeyDict() -# -# function desc(rnn) -# d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) -# CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b) -# return d -# end -# -# function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} -# y, h′ = CUDNN.forward(desc(m), x, h) -# return h′, y -# end -# -# function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} -# y, h′ = CUDNN.forward(desc(m), x, h) -# return h′, y -# end -# -# function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} -# y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2]) -# return (h′, c′), y -# end -# -# (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -# (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -# (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -# -# trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -# -# unbroadcast(x::AbstractArray, Δ) = -# size(x) == size(Δ) ? Δ : -# length(x) == length(Δ) ? trim(x, Δ) : -# trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) -# -# coerce_cuda(x::Union{CuArray,Nothing}) = x -# coerce_cuda(x::Tuple) = coerce_cuda.(x) -# -# coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0) -# -# function struct_grad!(cx::Zygote.Context, x, x̄) -# for f in fieldnames(typeof(x)) -# Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f)) -# end -# dx = Zygote.grad_mut(cx, x) -# dx[] = Zygote.accum(dx[], x̄) -# return dx -# end -# -# for RNN in (CuRNN, CuGRU) -# @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} -# (y, ho), back = CUDNN.pullback(desc(m), x, h) -# (ho, y), function (Δ) -# dho, dy = coerce_cuda(Δ) # Support FillArrays etc. -# m̄ = back(dy, dho) -# dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing)) -# (dm, unbroadcast(h, m̄.h), m̄.x) -# end -# end -# end -# -# @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} -# (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c) -# ((ho, co), y), function (Δ) -# dhc, dy = coerce_cuda(Δ) # Support FillArrays etc. -# dho, dco = dhc === nothing ? 
(nothing, nothing) : dhc -# m̄ = back(dy, dho, dco) -# dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing)) -# (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x) -# end -# end diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl index 2cd46633b0..ca0e6da67f 100644 --- a/test/rnn-test-jdb.jl +++ b/test/rnn-test-jdb.jl @@ -8,10 +8,10 @@ using Statistics: mean ######################## # illustrate diverging behavior of GPU execution -feat = 32 -h_size = 64 +feat = 16 +h_size = 32 seq_len = 10 -batch_size = 32 +batch_size = 100 rnn = Chain(RNN(feat, h_size), Dense(h_size, 1, σ), @@ -22,31 +22,50 @@ Y = rand(batch_size, seq_len) ./ 10 ###################################### #### do not run for cpu execution #### -rnn = rnn |> gpu -X = gpu(X) -Y = gpu(Y) +rnn_gpu = rnn |> gpu +X_gpu = gpu(X) +Y_gpu = gpu(Y) ###################################### θ = Flux.params(rnn) +θ_gpu = Flux.params(rnn_gpu) function loss(x,y) l = mean((Flux.stack(map(rnn, x),2) .- y) .^ 2f0) Flux.reset!(rnn) return l end +function loss_gpu(x,y) + l = mean((Flux.stack(map(rnn_gpu, x),2) .- y) .^ 2f0) + Flux.reset!(rnn_gpu) + return l +end opt = ADAM(1e-4) -loss(X,Y) -Flux.reset!(rnn) -Flux.train!(loss, θ, [(X,Y)], opt) -loss(X,Y) -for i in 1:100 +opt_gpu = ADAM(1e-4) +# loss(X,Y) +# loss_gpu(X_gpu, Y_gpu) +# Flux.train!(loss, θ, [(X,Y)], opt) +# Flux.train!(loss_gpu, θ_gpu, [(X_gpu,Y_gpu)], opt_gpu) +# loss(X,Y) +# loss_gpu(X_gpu, Y_gpu) +for i in 1:50 + println("iter: ", i) Flux.train!(loss, θ, [(X,Y)], opt) + Flux.train!(loss_gpu, θ_gpu, [(X_gpu,Y_gpu)], opt_gpu) + println("loss_cpu: ", loss(X, Y)) + println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) end # loss should show an NA if run on GPU, but not on CPU after going over 100 iterations loss(X,Y) +loss_gpu(X_gpu, Y_gpu) Flux.reset!(rnn) Flux.train!(loss, θ, [(X,Y)], opt) +rnn.layers[1].init +rnn.layers[1].state +rnn_gpu.layers[1].init +rnn_gpu.layers[1].state + θ[1] θ[2] θ[3] @@ -122,12 +141,6 @@ loss(X,Y) Flux.stack(rnn.(X),2) rnn.(X) -using CUDA - -x1 = LSTM(16,8) -CUDA.CUDNN.RNNDesc(x1) - - ######################## # rnn test gpu ######################## From 829307e7b48c0345e4ad079484e3e1965fe39a66 Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Wed, 21 Oct 2020 03:16:25 -0400 Subject: [PATCH 06/17] original manifest --- Manifest.toml | 157 ++++++++++++++++++++++---------------------------- 1 file changed, 68 insertions(+), 89 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index c184b96a11..81fb59f7d9 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -14,21 +14,15 @@ version = "0.3.3" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "42c42f2221906892ceb765dbcb1a51deeffd86d7" +git-tree-sha1 = "0fac443759fa829ed8066db6cf1077d888bb6573" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "2.3.0" +version = "2.0.2" [[ArrayLayouts]] -deps = ["Compat", "FillArrays", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "613ee839b41d4fb5aef6494a9ac998d058729051" +deps = ["FillArrays", "LinearAlgebra"] +git-tree-sha1 = "951c3fc1ff93497c88fb1dfa893f4de55d0b38e3" uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "0.4.9" - -[[Artifacts]] -deps = ["Pkg"] -git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744" -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -version = "1.3.0" +version = "0.3.8" [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -46,21 +40,21 @@ version = "0.4.1" [[CUDA]] deps = ["AbstractFFTs", "Adapt", "BinaryProvider", "CEnum", "DataStructures", "ExprTools", 
"GPUArrays", "GPUCompiler", "LLVM", "Libdl", "LinearAlgebra", "Logging", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"] -git-tree-sha1 = "83bfd180e2f842f6d4ee315a6db8665e9aa0c19b" +git-tree-sha1 = "2c007bb3931360013d32bcd53a727e0a7f144e00" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "1.3.3" +version = "1.2.1" [[ChainRules]] deps = ["ChainRulesCore", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "04afa37b7ea5fe1964ebec15b7f4769276b4edc5" +git-tree-sha1 = "f54401294c1ed68d7cc2a7ea280ec101bca722bc" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.26" +version = "0.7.12" [[ChainRulesCore]] -deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] -git-tree-sha1 = "aebbda0a7c644bd8739b34f2a1b1e48f114aab49" +deps = ["MuladdMacro"] +git-tree-sha1 = "971b03f25bdf2acab79f1c51afc717f9dccf43c2" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.17" +version = "0.9.5" [[CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -70,15 +64,15 @@ version = "0.7.0" [[ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "4bffea7ed1a9f0f3d1a131bbcd4b925548d75288" +git-tree-sha1 = "607c0ea16cb32af49ea2976f90c0c5acbca37d21" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.10.9" +version = "0.10.8" [[Colors]] deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"] -git-tree-sha1 = "008d6bc68dea6beb6303fdc37188cb557391ebf2" +git-tree-sha1 = "5639e44833cfcf78c6a73fbceb4da75611d312cd" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.4" +version = "0.12.3" [[CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -86,17 +80,11 @@ git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" version = "0.3.0" -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "f76e41cf110de7176a657c72409e722cfc86fbb6" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "3.20.0" - [[CompilerSupportLibraries_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70" +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612" uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.3.4+0" +version = "0.3.3+0" [[CpuId]] deps = ["Markdown", "Test"] @@ -110,10 +98,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.3.0" [[DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "db07bb22795762895b60e44d62b34b16c982a687" +deps = ["InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "88d48e133e6d3dd68183309877eac74393daa7eb" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.7" +version = "0.17.20" [[Dates]] deps = ["Printf"] @@ -141,20 +129,20 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1" +git-tree-sha1 = "c5714d9bcdba66389612dc4c47ed827c64112997" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.3" +version = "0.8.2" [[ExprTools]] -git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e" +git-tree-sha1 = 
"6f0517056812fd6aa3af23d4b70d5325a2ae4e95" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.3" +version = "0.1.1" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "502b3de6039d5b78c76118423858d981349f3823" +git-tree-sha1 = "4863cbb7910079369e258dee4add9d06ead5063a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.9.7" +version = "0.8.14" [[FixedPointNumbers]] deps = ["Statistics"] @@ -180,36 +168,31 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[GPUArrays]] deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"] -git-tree-sha1 = "da6398282abd2a8c0dc3e55b49d984fcc2c582e5" +git-tree-sha1 = "9027d6128645d227fa9ba7d5b3aa00af696b9aaf" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "5.2.1" +version = "5.0.0" [[GPUCompiler]] deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "05097d81898c527e3bf218bb083ad0ead4378e5f" +git-tree-sha1 = "d6db068d8373bb45ac17e14744cc1d3681d0dbc7" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.6.1" +version = "0.5.5" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "a8d88c05a23b44b4da6cf4fb5659e13ff95e0f47" +git-tree-sha1 = "6875ae3cfcb9a50af80553d5cc825f406e8d13bc" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.1" +version = "0.4.0" [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[JLLWrappers]] -git-tree-sha1 = "7cec881362e5b4e367ff0279dd99a06526d51a55" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.1.2" - [[Juno]] deps = ["Base64", "Logging", "Media", "Profile"] -git-tree-sha1 = "07cb43290a840908a771552911a6274bc6c072c7" +git-tree-sha1 = "90976c3ab792a98d240d42f9df07420ccfc60668" uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.8.4" +version = "0.8.3" [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] @@ -233,15 +216,15 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[LoopVectorization]] deps = ["DocStringExtensions", "LinearAlgebra", "OffsetArrays", "SIMDPirates", "SLEEFPirates", "UnPack", "VectorizationBase"] -git-tree-sha1 = "3242a8f411e19eda9adc49d0b877681975c11375" +git-tree-sha1 = "5d1d5fb3e9a4c18ae1aa8cb16b0fbdbfeb7db5ed" uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.8.26" +version = "0.8.22" [[MacroTools]] deps = ["Markdown", "Random"] -git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" +git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.6" +version = "0.5.5" [[Markdown]] deps = ["Base64"] @@ -255,9 +238,9 @@ version = "0.5.0" [[Missings]] deps = ["DataAPI"] -git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8" +git-tree-sha1 = "de0a5ce9e5289f27df672ffabef4d1e5861247d5" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.4" +version = "0.4.3" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -269,9 +252,9 @@ version = "0.2.2" [[NNlib]] deps = ["Libdl", "LinearAlgebra", "Pkg", "Requires", "Statistics"] -git-tree-sha1 = "1ef04283efe283be08e2d0de842f5e5286dd0b7a" +git-tree-sha1 = "8ec4693a5422f0b064ce324f59351f24aa474893" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.7.5" +version = "0.7.4" [[NaNMath]] git-tree-sha1 = "c84c576296d0e2fbb3fc134d3e09086b3ea617cd" @@ -279,20 +262,20 @@ uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" version = "0.3.4" [[OffsetArrays]] -git-tree-sha1 = "a416e2f267e2c8729f25bcaf1ce19d2893faf393" +git-tree-sha1 = 
"2066e16af994955287f2e03ba1d9e890eb43b0dd" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.3.1" +version = "1.1.2" [[OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] +git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+4" +version = "0.5.3+3" [[OrderedCollections]] -git-tree-sha1 = "16c08bf5dba06609fe45e30860092d6fa41fde7b" +git-tree-sha1 = "293b70ac1780f9584c89268a6e2a560d938a7065" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.1" +version = "1.3.0" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -322,18 +305,18 @@ version = "0.2.0" [[Requires]] deps = ["UUIDs"] -git-tree-sha1 = "28faf1c963ca1dc3ec87f166d92982e3c4a1f66d" +git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.1.0" +version = "1.0.1" [[SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" [[SIMDPirates]] deps = ["VectorizationBase"] -git-tree-sha1 = "450d163d3279a1d35e3aad3352a5167ef21b84a4" +git-tree-sha1 = "884df77ee290b45bcf9d2edcdb186afaf1281c39" uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a" -version = "0.8.25" +version = "0.8.23" [[SLEEFPirates]] deps = ["Libdl", "SIMDPirates", "VectorizationBase"] @@ -344,10 +327,6 @@ version = "0.5.5" [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -379,9 +358,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61" +git-tree-sha1 = "a6102b1f364befdb05746f386b67c6b7e3262c45" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.2" +version = "0.33.0" [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] @@ -404,36 +383,36 @@ deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +git-tree-sha1 = "d4bfa022cd30df012700cf380af2141961bb3bfb" uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" +version = "1.0.1" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [[VectorizationBase]] deps = ["CpuId", "Libdl", "LinearAlgebra"] -git-tree-sha1 = "03e2fbb479a1ea350398195b6fbf439bae0f8260" +git-tree-sha1 = "09a4d4896aac75199b332eb26a8d0a9e4655e34a" uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -version = "0.12.33" +version = "0.12.30" [[ZipFile]] deps = ["Libdl", "Printf", "Zlib_jll"] -git-tree-sha1 = "c3a5637e27e914a7a445b8d0ad063d701931e9f7" +git-tree-sha1 = "254975fef2fc526583bb9b7c9420fe66ffe09f2f" uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.9.3" +version = "0.9.2" [[Zlib_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "320228915c8debb12cb434c59057290f0834dbf6" +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "d5bba6485811931e4b8958e2d7ca3738273ac468" uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.11+18" +version = "1.2.11+15" [[Zygote]] deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "DiffRules", 
"Distributed", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "b3f792937233abc61d4020a9e89a7f2bcb67333e" +git-tree-sha1 = "0079d92995b0fbcffd5d475d49ec8ca49375c471" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.5.9" +version = "0.5.4" [[ZygoteRules]] deps = ["MacroTools"] From 759de3fb8ccc459a3476a3134e4c2059d8aa45f8 Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Wed, 21 Oct 2020 03:18:19 -0400 Subject: [PATCH 07/17] clean --- src/layers/recurrent.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 44a127c282..c86e36f9de 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -52,7 +52,7 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: rnn.state = hidden(rnn.cell) ``` """ -reset!(m::Recur) = (m.state = m.init) #modif to broadcast value of init into state +reset!(m::Recur) = (m.state = m.init) reset!(m) = foreach(reset!, functor(m)[1]) flip(f, xs) = reverse(f.(reverse(xs))) @@ -67,7 +67,7 @@ mutable struct RNNCell{F,A,V} end RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = - RNNCell(σ, init(out, in), init(out, out), init(out)) # remove h/init initialization + RNNCell(σ, init(out, in), init(out, out), init(out)) function (m::RNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b From dcf48080efbc4533f13f9bebe1bbad1c6d1fc0e5 Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Thu, 22 Oct 2020 00:24:00 -0400 Subject: [PATCH 08/17] make RNN/LSTM/GRU Cells immutable - #1089 check that CUDNN drop solves for too many wrappers - #1259 --- src/layers/recurrent.jl | 6 +++--- test/rnn-test-jdb.jl | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index c86e36f9de..a1497832cc 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -59,7 +59,7 @@ flip(f, xs) = reverse(f.(reverse(xs))) # Vanilla RNN -mutable struct RNNCell{F,A,V} +struct RNNCell{F,A,V} σ::F Wi::A Wh::A @@ -96,7 +96,7 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) # LSTM -mutable struct LSTMCell{A,V} +struct LSTMCell{A,V} Wi::A Wh::A b::V @@ -141,7 +141,7 @@ LSTM(a...; ka...) 
= Recur(LSTMCell(a...; ka...)) # GRU -mutable struct GRUCell{A,V} +struct GRUCell{A,V} Wi::A Wh::A b::V diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl index ca0e6da67f..278358f93c 100644 --- a/test/rnn-test-jdb.jl +++ b/test/rnn-test-jdb.jl @@ -3,6 +3,14 @@ using Flux # using CUDA using Statistics: mean +################################################ +# Too many wrappers issue #1259 +################################################ +m = RNN(3,2) |> gpu +x = CUDA.ones(3,2) +gs = gradient(() -> sum(m(x)), params(m)) +gs[m.cell.Wi] + ######################## # RNN test gpu ######################## From a8a95d73a872fc52ddd6c361bdc872aa909e9601 Mon Sep 17 00:00:00 2001 From: jeremiedb Date: Thu, 22 Oct 2020 00:37:31 -0400 Subject: [PATCH 09/17] test changes --- test/rnn-test-jdb.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl index 278358f93c..58403764fc 100644 --- a/test/rnn-test-jdb.jl +++ b/test/rnn-test-jdb.jl @@ -1,6 +1,6 @@ using Revise using Flux -# using CUDA +using CUDA using Statistics: mean ################################################ @@ -16,9 +16,9 @@ gs[m.cell.Wi] ######################## # illustrate diverging behavior of GPU execution -feat = 16 -h_size = 32 -seq_len = 10 +feat = 32 +h_size = 64 +seq_len = 20 batch_size = 100 rnn = Chain(RNN(feat, h_size), From d6fdba2627410d72e1882208efeda7b762f9dcc6 Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Wed, 28 Oct 2020 19:48:18 -0400 Subject: [PATCH 10/17] toy tests --- test/rnn-from-scractch.jl | 131 ++++++++++++++++++++++++++++++++++++++ test/rnn-test-jdb.jl | 10 +-- 2 files changed, 133 insertions(+), 8 deletions(-) create mode 100644 test/rnn-from-scractch.jl diff --git a/test/rnn-from-scractch.jl b/test/rnn-from-scractch.jl new file mode 100644 index 0000000000..7623652a25 --- /dev/null +++ b/test/rnn-from-scractch.jl @@ -0,0 +1,131 @@ +using Revise +using Flux +using Zygote +using CUDA +using Statistics: mean + +################################################ +# Define operators +################################################ +mutable struct MyRecur{T} + cell::T + init + state +end + +MyRecur(m, h=hidden(m)) = MyRecur(m, h, h) + +function (m::MyRecur)(xs...) + h, y = m.cell(m.state, xs...) + m.state = h + return y +end + +# Flux.@functor MyRecur cell, init +Flux.@functor MyRecur +Flux.trainable(a::MyRecur) = (a.cell, a.init) +# Flux.trainable(a::MyRecur) = (a.cell,) + +reset!(m::MyRecur) = (m.state = m.init) +reset!(m) = foreach(reset!, functor(m)[1]) + +# Vanilla RNN +struct MyRNNCell{F,A,V} + σ::F + Wi::A + Wh::A + b::V +end + +MyRNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform) = +MyRNNCell(σ, init(out, in), init(out, out), init(out)) + +function (m::MyRNNCell)(h, x) + σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b + h = σ.(Wi * x .+ Wh * h .+ b) + return h, h +end + +hidden(m::MyRNNCell) = m.h +Flux.@functor MyRNNCell + +MyRecur(m::MyRNNCell) = MyRecur(m, zeros(Float32, length(m.b)), zeros(Float32, length(m.b))) +MyRNN(a...; ka...) 
= MyRecur(MyRNNCell(a...; ka...)) + +######################## +# RNN test gpu +######################## + +# illustrate diverging behavior of GPU execution +feat = 32 +h_size = 64 +seq_len = 20 +batch_size = 100 + +rnn = Chain(MyRNN(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x, :)) + +X = [rand(Float32, feat, batch_size) for i in 1:seq_len] +Y = rand(Float32, batch_size, seq_len) ./ 10 + +#### transfer to gpu #### +rnn_gpu = rnn |> gpu +X_gpu = gpu(X) +Y_gpu = gpu(Y) + +θ = Flux.params(rnn) +θ_gpu = Flux.params(rnn_gpu) +function loss(x, y) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) + Flux.reset!(rnn) + return l +end +function loss_gpu(x, y) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) + Flux.reset!(rnn_gpu) + return l +end + +opt = Descent(1e-2) +opt_gpu = Descent(1e-2) +for i in 1:25 + println("iter: ", i) + Flux.train!(loss, θ, [(X, Y)], opt) + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + println("loss_cpu: ", loss(X, Y)) + println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) + println("θ[3][1:2]: ", θ[3][1:2]) + println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2]) + println("θ[4][1:2]: ", θ[4][1:2]) + println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2]) + println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2]) + println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2]) +end +# loss should show an NA if run on GPU, but not on CPU after going over 100 iterations +typeof(loss(X, Y)) +typeof(loss_gpu(X_gpu, Y_gpu)) + +loss(X,Y) +loss_gpu(X_gpu, Y_gpu) +Flux.reset!(rnn) +Flux.train!(loss, θ, [(X, Y)], opt) + +# @code_warntype loss(X, Y) +# @code_warntype loss_gpu(X_gpu, Y_gpu) + +rnn.layers[1].init +rnn.layers[1].state +rnn_gpu.layers[1].init +rnn_gpu.layers[1].state + +θ[1] +θ[2] +θ[3] +θ[4] +θ[5] +θ[6] + +θ_gpu[4] .= 0 + +rnn_gpu(X_gpu[1]) diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl index 58403764fc..5f88e64b59 100644 --- a/test/rnn-test-jdb.jl +++ b/test/rnn-test-jdb.jl @@ -48,14 +48,8 @@ function loss_gpu(x,y) return l end -opt = ADAM(1e-4) -opt_gpu = ADAM(1e-4) -# loss(X,Y) -# loss_gpu(X_gpu, Y_gpu) -# Flux.train!(loss, θ, [(X,Y)], opt) -# Flux.train!(loss_gpu, θ_gpu, [(X_gpu,Y_gpu)], opt_gpu) -# loss(X,Y) -# loss_gpu(X_gpu, Y_gpu) +opt = Descent(1e-2) +opt_gpu = Descent(1e-2) for i in 1:50 println("iter: ", i) Flux.train!(loss, θ, [(X,Y)], opt) From eafee050b2f813432803f1f909afb4e28d675173 Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Wed, 28 Oct 2020 20:30:26 -0400 Subject: [PATCH 11/17] put handle init state in cell --- test/rnn-from-scractch.jl | 21 +++++++-------- test/rnn-jdb.jl | 56 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 12 deletions(-) create mode 100644 test/rnn-jdb.jl diff --git a/test/rnn-from-scractch.jl b/test/rnn-from-scractch.jl index 7623652a25..8f6eff5f81 100644 --- a/test/rnn-from-scractch.jl +++ b/test/rnn-from-scractch.jl @@ -9,11 +9,10 @@ using Statistics: mean ################################################ mutable struct MyRecur{T} cell::T - init state end -MyRecur(m, h=hidden(m)) = MyRecur(m, h, h) +MyRecur(m) = MyRecur(m, init(m)) function (m::MyRecur)(xs...) h, y = m.cell(m.state, xs...) 
@@ -23,22 +22,22 @@ end # Flux.@functor MyRecur cell, init Flux.@functor MyRecur -Flux.trainable(a::MyRecur) = (a.cell, a.init) +Flux.trainable(a::MyRecur) = (a.cell,) # Flux.trainable(a::MyRecur) = (a.cell,) -reset!(m::MyRecur) = (m.state = m.init) +reset!(m::MyRecur) = (m.state = m.cell.init) reset!(m) = foreach(reset!, functor(m)[1]) # Vanilla RNN -struct MyRNNCell{F,A,V} +struct MyRNNCell{F,A,V,S} σ::F Wi::A Wh::A b::V + init::S end -MyRNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform) = -MyRNNCell(σ, init(out, in), init(out, out), init(out)) +MyRNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform) = MyRNNCell(σ, init(out, in), init(out, out), zeros(Float32, out), zeros(Float32, out)) function (m::MyRNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b @@ -46,10 +45,8 @@ function (m::MyRNNCell)(h, x) return h, h end -hidden(m::MyRNNCell) = m.h +init(m::MyRNNCell) = m.init Flux.@functor MyRNNCell - -MyRecur(m::MyRNNCell) = MyRecur(m, zeros(Float32, length(m.b)), zeros(Float32, length(m.b))) MyRNN(a...; ka...) = MyRecur(MyRNNCell(a...; ka...)) ######################## @@ -87,8 +84,8 @@ function loss_gpu(x, y) return l end -opt = Descent(1e-2) -opt_gpu = Descent(1e-2) +opt = ADAM(1e-3) +opt_gpu = ADAM(1e-3) for i in 1:25 println("iter: ", i) Flux.train!(loss, θ, [(X, Y)], opt) diff --git a/test/rnn-jdb.jl b/test/rnn-jdb.jl new file mode 100644 index 0000000000..198f4810ef --- /dev/null +++ b/test/rnn-jdb.jl @@ -0,0 +1,56 @@ +using Revise +using Flux +using Statistics: mean + +# illustrate diverging behavior of GPU execution +feat = 32 +h_size = 64 +seq_len = 20 +batch_size = 100 + +rnn = Chain(RNN(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x, :)) + +X = [rand(Float32, feat, batch_size) for i in 1:seq_len] +Y = rand(Float32, batch_size, seq_len) ./ 10 + +#### transfer to gpu #### +rnn_gpu = rnn |> gpu +X_gpu = gpu(X) +Y_gpu = gpu(Y) + +θ = Flux.params(rnn) +θ_gpu = Flux.params(rnn_gpu) +function loss(x, y) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) + Flux.reset!(rnn) + return l +end +function loss_gpu(x, y) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) + Flux.reset!(rnn_gpu) + return l +end + +opt = Descent(1e-2) +opt_gpu = Descent(1e-2) +for i in 1:25 + println("iter: ", i) + Flux.train!(loss, θ, [(X, Y)], opt) + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + println("loss_cpu: ", loss(X, Y)) + println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) + println("θ[3][1:2]: ", θ[3][1:2]) + println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2]) + println("θ[4][1:2]: ", θ[4][1:2]) + println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2]) + println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2]) + println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2]) +end + + +θ[3] +θ[4] +θ[5] +θ[6] \ No newline at end of file From 3adf4710ae0bb722d2be5142b1d63ec04b3338db Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Wed, 28 Oct 2020 20:41:12 -0400 Subject: [PATCH 12/17] rnn toy test --- test/rnn-test-jdb.jl | 57 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 test/rnn-test-jdb.jl diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl new file mode 100644 index 0000000000..07d04aa186 --- /dev/null +++ b/test/rnn-test-jdb.jl @@ -0,0 +1,57 @@ +using Revise +using Flux +using Statistics: mean + +# illustrate diverging behavior of GPU execution +feat = 32 +h_size = 64 +seq_len = 20 +batch_size = 100 + +rnn = Chain(RNN(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x, :)) + +X = [rand(Float32, feat, 
batch_size) for i in 1:seq_len] +Y = rand(Float32, batch_size, seq_len) ./ 10 + +#### transfer to gpu #### +rnn_gpu = rnn |> gpu +X_gpu = gpu(X) +Y_gpu = gpu(Y) + +θ = Flux.params(rnn) +θ_gpu = Flux.params(rnn_gpu) +function loss(x, y) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) + Flux.reset!(rnn) + return l +end +function loss_gpu(x, y) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) + Flux.reset!(rnn_gpu) + return l +end + +opt = ADAM(1e-3) +opt_gpu = ADAM(1e-3) +for i in 1:25 + println("iter: ", i) + Flux.train!(loss, θ, [(X, Y)], opt) + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + println("loss_cpu: ", loss(X, Y)) + println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) + println("θ[3][1:2]: ", θ[3][1:2]) + println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2]) + println("θ[4][1:2]: ", θ[4][1:2]) + println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2]) + println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2]) + println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2]) +end + +θ[1] +θ[2] +θ[3] +θ[4] +θ[5] +θ[6] \ No newline at end of file From 7ae661e2e2d1fa426fba5ea3ed9dad30047e1197 Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Fri, 30 Oct 2020 23:10:50 -0400 Subject: [PATCH 13/17] functionnal RNNs --- .gitignore | 1 + src/layers/recurrent.jl | 41 +++++----- test/rnn-from-scractch.jl | 139 +++++++++++++++++++++++++++----- test/rnn-jdb.jl | 161 ++++++++++++++++++++++++++++++++++++-- test/rnn-minimal.jl | 151 +++++++++++++++++++++++++++++++++++ 5 files changed, 446 insertions(+), 47 deletions(-) create mode 100644 test/rnn-minimal.jl diff --git a/.gitignore b/.gitignore index eb18605cdd..d7820b381a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ docs/build/ docs/site/ deps +.vscode \ No newline at end of file diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index a1497832cc..11c4850803 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -27,18 +27,17 @@ rnn.state # 60 """ mutable struct Recur{T} cell::T - init state end -Recur(m, h = hidden(m)) = Recur(m, h, h) - function (m::Recur)(xs...) - m.state, y = m.cell(m.state, xs...) + h, y = m.cell(m.state, xs...) + m.state = h return y end -@functor Recur cell, init +@functor Recur +trainable(a::Recur) = (a.cell,) Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")") @@ -52,7 +51,7 @@ Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to: rnn.state = hidden(rnn.cell) ``` """ -reset!(m::Recur) = (m.state = m.init) +reset!(m::Recur) = (m.state = m.cell.state) reset!(m) = foreach(reset!, functor(m)[1]) flip(f, xs) = reverse(f.(reverse(xs))) @@ -64,10 +63,11 @@ struct RNNCell{F,A,V} Wi::A Wh::A b::V + state::V end -RNNCell(in::Integer, out::Integer, σ = tanh; init = glorot_uniform) = - RNNCell(σ, init(out, in), init(out, out), init(out)) +RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) = + RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out)) function (m::RNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b @@ -75,8 +75,6 @@ function (m::RNNCell)(h, x) return h, h end -hidden(m::RNNCell) = m.h - @functor RNNCell function Base.show(io::IO, l::RNNCell) @@ -91,20 +89,23 @@ end The most basic recurrent layer; essentially acts as a `Dense` layer, but with the output fed back into the input each time step. """ -Recur(m::RNNCell) = Recur(m, zeros(length(m.b)), zeros(length(m.b))) +Recur(m::RNNCell) = Recur(m, m.state) RNN(a...; ka...) 
= Recur(RNNCell(a...; ka...)) # LSTM -struct LSTMCell{A,V} +struct LSTMCell{A,V,S} Wi::A Wh::A b::V + state::S end function LSTMCell(in::Integer, out::Integer; - init = glorot_uniform) - cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4)) + init = glorot_uniform, + initb = zeros, + init_state = zeros) + cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out), init_state(out))) cell.b[gate(out, 2)] .= 1 return cell end @@ -135,8 +136,9 @@ recurrent layer. Behaves like an RNN but generally exhibits a longer memory span See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals. """ -Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)), - (zeros(length(m.b)÷4), zeros(length(m.b)÷4))) +# Recur(m::LSTMCell) = Recur(m, (zeros(length(m.b)÷4), zeros(length(m.b)÷4)), +# (zeros(length(m.b)÷4), zeros(length(m.b)÷4))) +Recur(m::LSTMCell) = Recur(m, m.state) LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) # GRU @@ -145,10 +147,11 @@ struct GRUCell{A,V} Wi::A Wh::A b::V + state::V end -GRUCell(in, out; init = glorot_uniform) = - GRUCell(init(out * 3, in), init(out * 3, out), init(out * 3)) +GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) = + GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out)) function (m::GRUCell)(h, x) b, o = m.b, size(h, 1) @@ -174,7 +177,7 @@ RNN but generally exhibits a longer memory span over sequences. See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) for a good overview of the internals. """ -Recur(m::GRUCell) = Recur(m, zeros(length(m.b)÷3), zeros(length(m.b)÷3)) +Recur(m::GRUCell) = Recur(m, m.state) GRU(a...; ka...) = Recur(GRUCell(a...; ka...)) @adjoint function Broadcast.broadcasted(f::Recur, args...) diff --git a/test/rnn-from-scractch.jl b/test/rnn-from-scractch.jl index 8f6eff5f81..bf47247d17 100644 --- a/test/rnn-from-scractch.jl +++ b/test/rnn-from-scractch.jl @@ -1,7 +1,8 @@ using Revise using Flux -using Zygote +# using Zygote using CUDA +using Random using Statistics: mean ################################################ @@ -12,29 +13,27 @@ mutable struct MyRecur{T} state end -MyRecur(m) = MyRecur(m, init(m)) - function (m::MyRecur)(xs...) h, y = m.cell(m.state, xs...) m.state = h return y end -# Flux.@functor MyRecur cell, init Flux.@functor MyRecur Flux.trainable(a::MyRecur) = (a.cell,) -# Flux.trainable(a::MyRecur) = (a.cell,) -reset!(m::MyRecur) = (m.state = m.cell.init) -reset!(m) = foreach(reset!, functor(m)[1]) +function reset!(m::MyRecur) + m.state = m.cell.init +end +reset!(m) = foreach(reset!, Flux.functor(m)[1]) # Vanilla RNN -struct MyRNNCell{F,A,V,S} +mutable struct MyRNNCell{F,A,V} σ::F Wi::A Wh::A b::V - init::S + init::V end MyRNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform) = MyRNNCell(σ, init(out, in), init(out, out), zeros(Float32, out), zeros(Float32, out)) @@ -45,8 +44,10 @@ function (m::MyRNNCell)(h, x) return h, h end -init(m::MyRNNCell) = m.init +# init(m::MyRNNCell) = m.init Flux.@functor MyRNNCell +Flux.trainable(a::MyRNNCell) = (a.Wi, a.Wh, a.b, a.init) +MyRecur(m::MyRNNCell) = MyRecur(m, m.init) MyRNN(a...; ka...) = MyRecur(MyRNNCell(a...; ka...)) ######################## @@ -57,7 +58,7 @@ MyRNN(a...; ka...) 
= MyRecur(MyRNNCell(a...; ka...)) feat = 32 h_size = 64 seq_len = 20 -batch_size = 100 +batch_size = 128 rnn = Chain(MyRNN(feat, h_size), Dense(h_size, 1, σ), @@ -74,19 +75,21 @@ Y_gpu = gpu(Y) θ = Flux.params(rnn) θ_gpu = Flux.params(rnn_gpu) function loss(x, y) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) - Flux.reset!(rnn) + reset!(rnn) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) + # reset!(rnn) return l end function loss_gpu(x, y) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) - Flux.reset!(rnn_gpu) + reset!(rnn_gpu) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) + # reset!(rnn_gpu) return l end opt = ADAM(1e-3) opt_gpu = ADAM(1e-3) -for i in 1:25 +for i in 1:50 println("iter: ", i) Flux.train!(loss, θ, [(X, Y)], opt) Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) @@ -105,17 +108,22 @@ typeof(loss_gpu(X_gpu, Y_gpu)) loss(X,Y) loss_gpu(X_gpu, Y_gpu) -Flux.reset!(rnn) +reset!(rnn) +reset!(rnn_gpu) Flux.train!(loss, θ, [(X, Y)], opt) +Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) +loss(X,Y) +loss_gpu(X_gpu, Y_gpu) # @code_warntype loss(X, Y) # @code_warntype loss_gpu(X_gpu, Y_gpu) -rnn.layers[1].init rnn.layers[1].state -rnn_gpu.layers[1].init rnn_gpu.layers[1].state +rnn_gpu.layers[1].cell.init +rnn_gpu.layers[1].cell.init + θ[1] θ[2] θ[3] @@ -123,6 +131,95 @@ rnn_gpu.layers[1].state θ[5] θ[6] -θ_gpu[4] .= 0 +θ_gpu[1] +θ_gpu[2] +θ_gpu[3] +θ_gpu[4] +θ_gpu[5] + + +function debug_rnn(seed=123) + + # illustrate diverging behavior of GPU execution + feat = 2 + h_size = 3 + seq_len = 1 + batch_size = 1 + + function loss(x, y) + # reset!(rnn) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) + # reset!(rnn) + return l + end + function loss_gpu(x, y) + # reset!(rnn_gpu) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) + # reset!(rnn_gpu) + return l + end + + Random.seed!(seed) + rnn = Chain(MyRNN(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x, :)) + + X = [rand(Float32, feat, batch_size) for i in 1:seq_len] + Y = rand(Float32, batch_size, seq_len) ./ 10 + + # transfer to gpu #### + rnn_gpu = rnn |> gpu + X_gpu = gpu(X) + Y_gpu = gpu(Y) + + θ = Flux.params(rnn) + θ_gpu = Flux.params(rnn_gpu) + + opt = Descent(1e-2) + opt_gpu = Descent(1e-2) + + l_cpu = loss(X,Y) + l_gpu = loss_gpu(X_gpu,Y_gpu) + + # println("loss_ratio: ", l_cpu / l_gpu - 1) + + println("CPU") + println("loss: ", loss(X,Y)) + # println("X: ", X) + # println("Y: ", Y) + debug_train!(loss, θ, (X,Y)) + println() + + println("GPU") + println("loss: ", loss_gpu(X_gpu,Y_gpu)) + # println("X_gpu: ", X_gpu) + # println("Y_gpu: ", Y_gpu) + debug_train!(loss_gpu, θ_gpu, (X_gpu,Y_gpu)) + println() +end + +function debug_train!(loss, ps, d) + gs = gradient(ps) do + loss(d...) 
+ end + # x = first(ps) + @show ps[1] gs[ps[1]] + @show ps[2] gs[ps[2]] + @show ps[3] gs[ps[3]] + @show ps[4] gs[ps[4]] + @show ps[5] gs[ps[5]] +end + +# seed 19 and 77 are buggy +debug_rnn(15) +debug_rnn(19) +debug_rnn(77) -rnn_gpu(X_gpu[1]) +for i in 101:200 + println(i) + debug_rnn(i) +end + +gs = gradient(θ) do + loss(X,Y) +end diff --git a/test/rnn-jdb.jl b/test/rnn-jdb.jl index 198f4810ef..180cd77a9d 100644 --- a/test/rnn-jdb.jl +++ b/test/rnn-jdb.jl @@ -5,8 +5,8 @@ using Statistics: mean # illustrate diverging behavior of GPU execution feat = 32 h_size = 64 -seq_len = 20 -batch_size = 100 +seq_len = 50 +batch_size = 256 rnn = Chain(RNN(feat, h_size), Dense(h_size, 1, σ), @@ -23,18 +23,18 @@ Y_gpu = gpu(Y) θ = Flux.params(rnn) θ_gpu = Flux.params(rnn_gpu) function loss(x, y) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) Flux.reset!(rnn) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) return l end function loss_gpu(x, y) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) Flux.reset!(rnn_gpu) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) return l end -opt = Descent(1e-2) -opt_gpu = Descent(1e-2) +opt = ADAM(1e-3) +opt_gpu = ADAM(1e-3) for i in 1:25 println("iter: ", i) Flux.train!(loss, θ, [(X, Y)], opt) @@ -50,7 +50,154 @@ for i in 1:25 end +function speed_cpu(n=10) + for i in 1:n + Flux.train!(loss, θ, [(X, Y)], opt) + end + return loss(X, Y) +end + +function speed_gpu(n=10) + for i in 1:n + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + end + return loss_gpu(X_gpu, Y_gpu) +end + +@time speed_cpu(100) +@time speed_gpu(100) + θ[3] θ[4] θ[5] -θ[6] \ No newline at end of file +θ[6] + + +##################################### +# LSTM +##################################### +feat = 32 +h_size = 64 +seq_len = 50 +batch_size = 256 + +rnn = Chain(LSTM(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x, :)) + +X = [rand(Float32, feat, batch_size) for i in 1:seq_len] +Y = rand(Float32, batch_size, seq_len) ./ 10 + +#### transfer to gpu #### +rnn_gpu = rnn |> gpu +X_gpu = gpu(X) +Y_gpu = gpu(Y) + +θ = Flux.params(rnn) +θ_gpu = Flux.params(rnn_gpu) +function loss(x, y) + Flux.reset!(rnn) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) + return l +end +function loss_gpu(x, y) + Flux.reset!(rnn_gpu) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) + return l +end + +opt = ADAM(1e-3) +opt_gpu = ADAM(1e-3) + +for i in 1:5 + println("iter: ", i) + Flux.train!(loss, θ, [(X, Y)], opt) + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + println("loss_cpu: ", loss(X, Y)) + println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) +end + + +function speed_cpu(n=10) + for i in 1:n + Flux.train!(loss, θ, [(X, Y)], opt) + end + return loss(X, Y) +end + +function speed_gpu(n=10) + for i in 1:n + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + end + return loss_gpu(X_gpu, Y_gpu) +end + +@code_warntype rnn(X) + +@time speed_cpu(100) +@time speed_gpu(100) + + +##################################### +# GRU +##################################### +feat = 32 +h_size = 64 +seq_len = 50 +batch_size = 256 + +rnn = Chain(GRU(feat, h_size), + Dense(h_size, 1, σ), + x -> reshape(x, :)) + +X = [rand(Float32, feat, batch_size) for i in 1:seq_len] +Y = rand(Float32, batch_size, seq_len) ./ 10 + +#### transfer to gpu #### +rnn_gpu = rnn |> gpu +X_gpu = gpu(X) +Y_gpu = gpu(Y) + +θ = Flux.params(rnn) +θ_gpu = Flux.params(rnn_gpu) +function loss(x, y) + Flux.reset!(rnn) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) + return l +end +function loss_gpu(x, y) + Flux.reset!(rnn_gpu) + l = 
mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) + return l +end + +opt = ADAM(1e-3) +opt_gpu = ADAM(1e-3) + +for i in 1:5 + println("iter: ", i) + Flux.train!(loss, θ, [(X, Y)], opt) + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + println("loss_cpu: ", loss(X, Y)) + println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) +end + + +function speed_cpu(n=10) + for i in 1:n + Flux.train!(loss, θ, [(X, Y)], opt) + end + return loss(X, Y) +end + +function speed_gpu(n=10) + for i in 1:n + Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) + end + return loss_gpu(X_gpu, Y_gpu) +end + +@code_warntype rnn(X) + +@time speed_cpu(100) +@time speed_gpu(100) \ No newline at end of file diff --git a/test/rnn-minimal.jl b/test/rnn-minimal.jl new file mode 100644 index 0000000000..6e915aa2d9 --- /dev/null +++ b/test/rnn-minimal.jl @@ -0,0 +1,151 @@ +using Revise +using Flux +# using Zygote +# using CUDA +using Random +using Statistics: mean + +################################################ +# Define operators +################################################ +struct MyRecur{T} + cell::T + state +end + +function (m::MyRecur)(x) + y = m.cell(m.state, x) + # m.state = h + return y +end + +# Flux.@functor MyRecur cell, init +# Flux.@functor MyRecur (cell,) +Flux.@functor MyRecur +Flux.trainable(a::MyRecur) = (a.cell,) +# Flux.trainable(a::MyRecur) = (a.cell,) + +function reset!(m::MyRecur) + m.state = m.cell.init + # m.state .*= 0.0 +end +reset!(m) = foreach(reset!, Flux.functor(m)[1]) + +# Vanilla RNN +mutable struct MyRNNCell{F,A,V} + σ::F + Wi::A + Wh::A + b::V + init::V +end + +MyRNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform) = MyRNNCell(σ, init(out, in), init(out, out), init(out), zeros(Float32, out)) + +function (m::MyRNNCell)(h, x) + σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b + h = Wi * x .+ Wh * h .+ b + # h = σ.(Wi * x .+ Wh * h .+ b) + # return h, h + return h +end + +# init(m::MyRNNCell) = m.init +Flux.@functor MyRNNCell +Flux.trainable(a::MyRNNCell) = (a.Wi, a.Wh, a.b) +MyRecur(m::MyRNNCell) = MyRecur(m, Flux.glorot_uniform(length(m.init))) +MyRNN(a...; ka...) 
= MyRecur(MyRNNCell(a...; ka...)) + + +# RNN test gpu +function debug_rnn(seed=123) + + # illustrate diverging behavior of GPU execution + feat = 2 + h_size = 3 + seq_len = 1 + batch_size = 1 + + function loss(x, y) + # reset!(rnn) + # l1 = (Flux.stack(map(rnn, x), 2) .- y) + # println("l1: ", l1) + l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) + # l = mean((Flux.stack(map(rnn, x), 2) .- y)) + # l = mean((rnn(x[1]) .- y).^2f0) + # l = mean(rnn(x[1])) + # reset!(rnn) + return l + end + function loss_gpu(x, y) + # reset!(rnn_gpu) + # l1 = (Flux.stack(map(rnn_gpu, x), 2) .- y) + # println("l1 gpu: ", l1) + l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) + # l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y)) + # l = mean(rnn_gpu(x[1] .- y)) + # reset!(rnn_gpu) + return l + end + + Random.seed!(seed) + rnn = Chain(MyRNN(feat, h_size), + Dense(h_size, 1), + x -> reshape(x, :)) + + X = [rand(Float32, feat, batch_size) for i in 1:seq_len] + Y = rand(Float32, batch_size, seq_len) ./ 10 + + # transfer to gpu #### + rnn_gpu = rnn |> gpu + X_gpu = gpu(X) + Y_gpu = gpu(Y) + + θ = Flux.params(rnn) + θ_gpu = Flux.params(rnn_gpu) + + opt = Descent(1e-2) + opt_gpu = Descent(1e-2) + + l_cpu = loss(X,Y) + l_gpu = loss_gpu(X_gpu,Y_gpu) + + # println("loss_ratio: ", l_cpu / l_gpu - 1) + + println("CPU") + println("loss: ", loss(X,Y)) + # println("X: ", X) + # println("Y: ", Y) + debug_train!(loss, θ, (X,Y)) + println() + + println("GPU") + println("loss: ", loss_gpu(X_gpu,Y_gpu)) + # println("X_gpu: ", X_gpu) + # println("Y_gpu: ", Y_gpu) + debug_train!(loss_gpu, θ_gpu, (X_gpu,Y_gpu)) + println() +end + +function debug_train!(loss, ps, d) + gs = gradient(ps) do + loss(d...) + end + # x = first(ps) + @show ps[1] gs[ps[1]] + @show ps[2] gs[ps[2]] + @show ps[3] gs[ps[3]] + @show ps[4] gs[ps[4]] + @show ps[5] gs[ps[5]] +end + +# seed 19 and 77 are buggy +debug_rnn(15) +debug_rnn(19) +debug_rnn(77) + +for i in 101:200 + println(i) + debug_rnn(i) +end + From deafdcafb66cc7c2bdcda962a51e0b0f6831c430 Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Fri, 30 Oct 2020 23:24:14 -0400 Subject: [PATCH 14/17] check warntype --- test/rnn-jdb.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/rnn-jdb.jl b/test/rnn-jdb.jl index 180cd77a9d..bdd92ee6ef 100644 --- a/test/rnn-jdb.jl +++ b/test/rnn-jdb.jl @@ -64,6 +64,8 @@ function speed_gpu(n=10) return loss_gpu(X_gpu, Y_gpu) end +@code_warntype rnn(X[1]) + @time speed_cpu(100) @time speed_gpu(100) @@ -132,7 +134,7 @@ function speed_gpu(n=10) return loss_gpu(X_gpu, Y_gpu) end -@code_warntype rnn(X) +@code_warntype rnn(X[1]) @time speed_cpu(100) @time speed_gpu(100) @@ -197,7 +199,7 @@ function speed_gpu(n=10) return loss_gpu(X_gpu, Y_gpu) end -@code_warntype rnn(X) +@code_warntype rnn(X[1]) @time speed_cpu(100) @time speed_gpu(100) \ No newline at end of file From 6fd355f0b324cb70226cbd20b9078af1e8c19910 Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Sat, 31 Oct 2020 00:18:50 -0400 Subject: [PATCH 15/17] stable type output clean up test files --- src/layers/recurrent.jl | 18 +-- test/layers/recurrent.jl | 4 +- test/rnn-from-scractch.jl | 225 -------------------------------------- test/rnn-jdb.jl | 205 ---------------------------------- test/rnn-minimal.jl | 151 ------------------------- test/rnn-test-jdb.jl | 186 ------------------------------- 6 files changed, 11 insertions(+), 778 deletions(-) delete mode 100644 test/rnn-from-scractch.jl delete mode 100644 test/rnn-jdb.jl delete mode 100644 test/rnn-minimal.jl delete mode 100644 
test/rnn-test-jdb.jl diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 11c4850803..70ba5e6988 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -25,9 +25,9 @@ rnn.(1:10) # apply to a sequence rnn.state # 60 ``` """ -mutable struct Recur{T} +mutable struct Recur{T,S} cell::T - state + state::S end function (m::Recur)(xs...) @@ -58,16 +58,16 @@ flip(f, xs) = reverse(f.(reverse(xs))) # Vanilla RNN -struct RNNCell{F,A,V} +struct RNNCell{F,A,V,S} σ::F Wi::A Wh::A b::V - state::V + state::S end RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros) = - RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out)) + RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1)) function (m::RNNCell)(h, x) σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b @@ -105,7 +105,7 @@ function LSTMCell(in::Integer, out::Integer; init = glorot_uniform, initb = zeros, init_state = zeros) - cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out), init_state(out))) + cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1))) cell.b[gate(out, 2)] .= 1 return cell end @@ -143,15 +143,15 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) # GRU -struct GRUCell{A,V} +struct GRUCell{A,V,S} Wi::A Wh::A b::V - state::V + state::S end GRUCell(in, out; init = glorot_uniform, initb = zeros, init_state = zeros) = - GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out)) + GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1)) function (m::GRUCell)(h, x) b, o = m.b, size(h, 1) diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 2bb093fc96..0691f5929e 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -1,6 +1,6 @@ # Ref FluxML/Flux.jl#1209 @testset "BPTT" begin - seq = [rand(2) for i = 1:3] + seq = [rand(Float32, (2,1)) for i = 1:3] for r ∈ [RNN,] rnn = r(2,3) Flux.reset!(rnn) @@ -11,7 +11,7 @@ bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh * tanh.(rnn.cell.Wi * seq[2] + Wh * tanh.(rnn.cell.Wi * seq[1] + - Wh * rnn.init + Wh * rnn.cell.state + rnn.cell.b) + rnn.cell.b) + rnn.cell.b)), diff --git a/test/rnn-from-scractch.jl b/test/rnn-from-scractch.jl deleted file mode 100644 index bf47247d17..0000000000 --- a/test/rnn-from-scractch.jl +++ /dev/null @@ -1,225 +0,0 @@ -using Revise -using Flux -# using Zygote -using CUDA -using Random -using Statistics: mean - -################################################ -# Define operators -################################################ -mutable struct MyRecur{T} - cell::T - state -end - -function (m::MyRecur)(xs...) - h, y = m.cell(m.state, xs...) - m.state = h - return y -end - -Flux.@functor MyRecur -Flux.trainable(a::MyRecur) = (a.cell,) - -function reset!(m::MyRecur) - m.state = m.cell.init -end -reset!(m) = foreach(reset!, Flux.functor(m)[1]) - -# Vanilla RNN -mutable struct MyRNNCell{F,A,V} - σ::F - Wi::A - Wh::A - b::V - init::V -end - -MyRNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform) = MyRNNCell(σ, init(out, in), init(out, out), zeros(Float32, out), zeros(Float32, out)) - -function (m::MyRNNCell)(h, x) - σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b - h = σ.(Wi * x .+ Wh * h .+ b) - return h, h -end - -# init(m::MyRNNCell) = m.init -Flux.@functor MyRNNCell -Flux.trainable(a::MyRNNCell) = (a.Wi, a.Wh, a.b, a.init) -MyRecur(m::MyRNNCell) = MyRecur(m, m.init) -MyRNN(a...; ka...) 
= MyRecur(MyRNNCell(a...; ka...)) - -######################## -# RNN test gpu -######################## - -# illustrate diverging behavior of GPU execution -feat = 32 -h_size = 64 -seq_len = 20 -batch_size = 128 - -rnn = Chain(MyRNN(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x, :)) - -X = [rand(Float32, feat, batch_size) for i in 1:seq_len] -Y = rand(Float32, batch_size, seq_len) ./ 10 - -#### transfer to gpu #### -rnn_gpu = rnn |> gpu -X_gpu = gpu(X) -Y_gpu = gpu(Y) - -θ = Flux.params(rnn) -θ_gpu = Flux.params(rnn_gpu) -function loss(x, y) - reset!(rnn) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) - # reset!(rnn) - return l -end -function loss_gpu(x, y) - reset!(rnn_gpu) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) - # reset!(rnn_gpu) - return l -end - -opt = ADAM(1e-3) -opt_gpu = ADAM(1e-3) -for i in 1:50 - println("iter: ", i) - Flux.train!(loss, θ, [(X, Y)], opt) - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - println("loss_cpu: ", loss(X, Y)) - println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) - println("θ[3][1:2]: ", θ[3][1:2]) - println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2]) - println("θ[4][1:2]: ", θ[4][1:2]) - println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2]) - println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2]) - println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2]) -end -# loss should show an NA if run on GPU, but not on CPU after going over 100 iterations -typeof(loss(X, Y)) -typeof(loss_gpu(X_gpu, Y_gpu)) - -loss(X,Y) -loss_gpu(X_gpu, Y_gpu) -reset!(rnn) -reset!(rnn_gpu) -Flux.train!(loss, θ, [(X, Y)], opt) -Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) -loss(X,Y) -loss_gpu(X_gpu, Y_gpu) - -# @code_warntype loss(X, Y) -# @code_warntype loss_gpu(X_gpu, Y_gpu) - -rnn.layers[1].state -rnn_gpu.layers[1].state - -rnn_gpu.layers[1].cell.init -rnn_gpu.layers[1].cell.init - -θ[1] -θ[2] -θ[3] -θ[4] -θ[5] -θ[6] - -θ_gpu[1] -θ_gpu[2] -θ_gpu[3] -θ_gpu[4] -θ_gpu[5] - - -function debug_rnn(seed=123) - - # illustrate diverging behavior of GPU execution - feat = 2 - h_size = 3 - seq_len = 1 - batch_size = 1 - - function loss(x, y) - # reset!(rnn) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) - # reset!(rnn) - return l - end - function loss_gpu(x, y) - # reset!(rnn_gpu) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) - # reset!(rnn_gpu) - return l - end - - Random.seed!(seed) - rnn = Chain(MyRNN(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x, :)) - - X = [rand(Float32, feat, batch_size) for i in 1:seq_len] - Y = rand(Float32, batch_size, seq_len) ./ 10 - - # transfer to gpu #### - rnn_gpu = rnn |> gpu - X_gpu = gpu(X) - Y_gpu = gpu(Y) - - θ = Flux.params(rnn) - θ_gpu = Flux.params(rnn_gpu) - - opt = Descent(1e-2) - opt_gpu = Descent(1e-2) - - l_cpu = loss(X,Y) - l_gpu = loss_gpu(X_gpu,Y_gpu) - - # println("loss_ratio: ", l_cpu / l_gpu - 1) - - println("CPU") - println("loss: ", loss(X,Y)) - # println("X: ", X) - # println("Y: ", Y) - debug_train!(loss, θ, (X,Y)) - println() - - println("GPU") - println("loss: ", loss_gpu(X_gpu,Y_gpu)) - # println("X_gpu: ", X_gpu) - # println("Y_gpu: ", Y_gpu) - debug_train!(loss_gpu, θ_gpu, (X_gpu,Y_gpu)) - println() -end - -function debug_train!(loss, ps, d) - gs = gradient(ps) do - loss(d...) 
- end - # x = first(ps) - @show ps[1] gs[ps[1]] - @show ps[2] gs[ps[2]] - @show ps[3] gs[ps[3]] - @show ps[4] gs[ps[4]] - @show ps[5] gs[ps[5]] -end - -# seed 19 and 77 are buggy -debug_rnn(15) -debug_rnn(19) -debug_rnn(77) - -for i in 101:200 - println(i) - debug_rnn(i) -end - -gs = gradient(θ) do - loss(X,Y) -end diff --git a/test/rnn-jdb.jl b/test/rnn-jdb.jl deleted file mode 100644 index bdd92ee6ef..0000000000 --- a/test/rnn-jdb.jl +++ /dev/null @@ -1,205 +0,0 @@ -using Revise -using Flux -using Statistics: mean - -# illustrate diverging behavior of GPU execution -feat = 32 -h_size = 64 -seq_len = 50 -batch_size = 256 - -rnn = Chain(RNN(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x, :)) - -X = [rand(Float32, feat, batch_size) for i in 1:seq_len] -Y = rand(Float32, batch_size, seq_len) ./ 10 - -#### transfer to gpu #### -rnn_gpu = rnn |> gpu -X_gpu = gpu(X) -Y_gpu = gpu(Y) - -θ = Flux.params(rnn) -θ_gpu = Flux.params(rnn_gpu) -function loss(x, y) - Flux.reset!(rnn) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) - return l -end -function loss_gpu(x, y) - Flux.reset!(rnn_gpu) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) - return l -end - -opt = ADAM(1e-3) -opt_gpu = ADAM(1e-3) -for i in 1:25 - println("iter: ", i) - Flux.train!(loss, θ, [(X, Y)], opt) - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - println("loss_cpu: ", loss(X, Y)) - println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) - println("θ[3][1:2]: ", θ[3][1:2]) - println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2]) - println("θ[4][1:2]: ", θ[4][1:2]) - println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2]) - println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2]) - println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2]) -end - - -function speed_cpu(n=10) - for i in 1:n - Flux.train!(loss, θ, [(X, Y)], opt) - end - return loss(X, Y) -end - -function speed_gpu(n=10) - for i in 1:n - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - end - return loss_gpu(X_gpu, Y_gpu) -end - -@code_warntype rnn(X[1]) - -@time speed_cpu(100) -@time speed_gpu(100) - -θ[3] -θ[4] -θ[5] -θ[6] - - -##################################### -# LSTM -##################################### -feat = 32 -h_size = 64 -seq_len = 50 -batch_size = 256 - -rnn = Chain(LSTM(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x, :)) - -X = [rand(Float32, feat, batch_size) for i in 1:seq_len] -Y = rand(Float32, batch_size, seq_len) ./ 10 - -#### transfer to gpu #### -rnn_gpu = rnn |> gpu -X_gpu = gpu(X) -Y_gpu = gpu(Y) - -θ = Flux.params(rnn) -θ_gpu = Flux.params(rnn_gpu) -function loss(x, y) - Flux.reset!(rnn) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) - return l -end -function loss_gpu(x, y) - Flux.reset!(rnn_gpu) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) - return l -end - -opt = ADAM(1e-3) -opt_gpu = ADAM(1e-3) - -for i in 1:5 - println("iter: ", i) - Flux.train!(loss, θ, [(X, Y)], opt) - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - println("loss_cpu: ", loss(X, Y)) - println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) -end - - -function speed_cpu(n=10) - for i in 1:n - Flux.train!(loss, θ, [(X, Y)], opt) - end - return loss(X, Y) -end - -function speed_gpu(n=10) - for i in 1:n - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - end - return loss_gpu(X_gpu, Y_gpu) -end - -@code_warntype rnn(X[1]) - -@time speed_cpu(100) -@time speed_gpu(100) - - -##################################### -# GRU -##################################### -feat = 32 -h_size = 64 -seq_len = 50 -batch_size = 256 - -rnn = 
Chain(GRU(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x, :)) - -X = [rand(Float32, feat, batch_size) for i in 1:seq_len] -Y = rand(Float32, batch_size, seq_len) ./ 10 - -#### transfer to gpu #### -rnn_gpu = rnn |> gpu -X_gpu = gpu(X) -Y_gpu = gpu(Y) - -θ = Flux.params(rnn) -θ_gpu = Flux.params(rnn_gpu) -function loss(x, y) - Flux.reset!(rnn) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) - return l -end -function loss_gpu(x, y) - Flux.reset!(rnn_gpu) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) - return l -end - -opt = ADAM(1e-3) -opt_gpu = ADAM(1e-3) - -for i in 1:5 - println("iter: ", i) - Flux.train!(loss, θ, [(X, Y)], opt) - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - println("loss_cpu: ", loss(X, Y)) - println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) -end - - -function speed_cpu(n=10) - for i in 1:n - Flux.train!(loss, θ, [(X, Y)], opt) - end - return loss(X, Y) -end - -function speed_gpu(n=10) - for i in 1:n - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - end - return loss_gpu(X_gpu, Y_gpu) -end - -@code_warntype rnn(X[1]) - -@time speed_cpu(100) -@time speed_gpu(100) \ No newline at end of file diff --git a/test/rnn-minimal.jl b/test/rnn-minimal.jl deleted file mode 100644 index 6e915aa2d9..0000000000 --- a/test/rnn-minimal.jl +++ /dev/null @@ -1,151 +0,0 @@ -using Revise -using Flux -# using Zygote -# using CUDA -using Random -using Statistics: mean - -################################################ -# Define operators -################################################ -struct MyRecur{T} - cell::T - state -end - -function (m::MyRecur)(x) - y = m.cell(m.state, x) - # m.state = h - return y -end - -# Flux.@functor MyRecur cell, init -# Flux.@functor MyRecur (cell,) -Flux.@functor MyRecur -Flux.trainable(a::MyRecur) = (a.cell,) -# Flux.trainable(a::MyRecur) = (a.cell,) - -function reset!(m::MyRecur) - m.state = m.cell.init - # m.state .*= 0.0 -end -reset!(m) = foreach(reset!, Flux.functor(m)[1]) - -# Vanilla RNN -mutable struct MyRNNCell{F,A,V} - σ::F - Wi::A - Wh::A - b::V - init::V -end - -MyRNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform) = MyRNNCell(σ, init(out, in), init(out, out), init(out), zeros(Float32, out)) - -function (m::MyRNNCell)(h, x) - σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b - h = Wi * x .+ Wh * h .+ b - # h = σ.(Wi * x .+ Wh * h .+ b) - # return h, h - return h -end - -# init(m::MyRNNCell) = m.init -Flux.@functor MyRNNCell -Flux.trainable(a::MyRNNCell) = (a.Wi, a.Wh, a.b) -MyRecur(m::MyRNNCell) = MyRecur(m, Flux.glorot_uniform(length(m.init))) -MyRNN(a...; ka...) 
= MyRecur(MyRNNCell(a...; ka...)) - - -# RNN test gpu -function debug_rnn(seed=123) - - # illustrate diverging behavior of GPU execution - feat = 2 - h_size = 3 - seq_len = 1 - batch_size = 1 - - function loss(x, y) - # reset!(rnn) - # l1 = (Flux.stack(map(rnn, x), 2) .- y) - # println("l1: ", l1) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2) - # l = mean((Flux.stack(map(rnn, x), 2) .- y)) - # l = mean((rnn(x[1]) .- y).^2f0) - # l = mean(rnn(x[1])) - # reset!(rnn) - return l - end - function loss_gpu(x, y) - # reset!(rnn_gpu) - # l1 = (Flux.stack(map(rnn_gpu, x), 2) .- y) - # println("l1 gpu: ", l1) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2) - # l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y)) - # l = mean(rnn_gpu(x[1] .- y)) - # reset!(rnn_gpu) - return l - end - - Random.seed!(seed) - rnn = Chain(MyRNN(feat, h_size), - Dense(h_size, 1), - x -> reshape(x, :)) - - X = [rand(Float32, feat, batch_size) for i in 1:seq_len] - Y = rand(Float32, batch_size, seq_len) ./ 10 - - # transfer to gpu #### - rnn_gpu = rnn |> gpu - X_gpu = gpu(X) - Y_gpu = gpu(Y) - - θ = Flux.params(rnn) - θ_gpu = Flux.params(rnn_gpu) - - opt = Descent(1e-2) - opt_gpu = Descent(1e-2) - - l_cpu = loss(X,Y) - l_gpu = loss_gpu(X_gpu,Y_gpu) - - # println("loss_ratio: ", l_cpu / l_gpu - 1) - - println("CPU") - println("loss: ", loss(X,Y)) - # println("X: ", X) - # println("Y: ", Y) - debug_train!(loss, θ, (X,Y)) - println() - - println("GPU") - println("loss: ", loss_gpu(X_gpu,Y_gpu)) - # println("X_gpu: ", X_gpu) - # println("Y_gpu: ", Y_gpu) - debug_train!(loss_gpu, θ_gpu, (X_gpu,Y_gpu)) - println() -end - -function debug_train!(loss, ps, d) - gs = gradient(ps) do - loss(d...) - end - # x = first(ps) - @show ps[1] gs[ps[1]] - @show ps[2] gs[ps[2]] - @show ps[3] gs[ps[3]] - @show ps[4] gs[ps[4]] - @show ps[5] gs[ps[5]] -end - -# seed 19 and 77 are buggy -debug_rnn(15) -debug_rnn(19) -debug_rnn(77) - -for i in 101:200 - println(i) - debug_rnn(i) -end - diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl deleted file mode 100644 index 5f88e64b59..0000000000 --- a/test/rnn-test-jdb.jl +++ /dev/null @@ -1,186 +0,0 @@ -using Revise -using Flux -using CUDA -using Statistics: mean - -################################################ -# Too many wrappers issue #1259 -################################################ -m = RNN(3,2) |> gpu -x = CUDA.ones(3,2) -gs = gradient(() -> sum(m(x)), params(m)) -gs[m.cell.Wi] - -######################## -# RNN test gpu -######################## - -# illustrate diverging behavior of GPU execution -feat = 32 -h_size = 64 -seq_len = 20 -batch_size = 100 - -rnn = Chain(RNN(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x,:)) - -X = [rand(feat, batch_size) for i in 1:seq_len] -Y = rand(batch_size, seq_len) ./ 10 - -###################################### -#### do not run for cpu execution #### -rnn_gpu = rnn |> gpu -X_gpu = gpu(X) -Y_gpu = gpu(Y) -###################################### - -θ = Flux.params(rnn) -θ_gpu = Flux.params(rnn_gpu) -function loss(x,y) - l = mean((Flux.stack(map(rnn, x),2) .- y) .^ 2f0) - Flux.reset!(rnn) - return l -end -function loss_gpu(x,y) - l = mean((Flux.stack(map(rnn_gpu, x),2) .- y) .^ 2f0) - Flux.reset!(rnn_gpu) - return l -end - -opt = Descent(1e-2) -opt_gpu = Descent(1e-2) -for i in 1:50 - println("iter: ", i) - Flux.train!(loss, θ, [(X,Y)], opt) - Flux.train!(loss_gpu, θ_gpu, [(X_gpu,Y_gpu)], opt_gpu) - println("loss_cpu: ", loss(X, Y)) - println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) -end -# loss should show an NA if run on GPU, but not on CPU 
after going over 100 iterations -loss(X,Y) -loss_gpu(X_gpu, Y_gpu) -Flux.reset!(rnn) -Flux.train!(loss, θ, [(X,Y)], opt) - -rnn.layers[1].init -rnn.layers[1].state -rnn_gpu.layers[1].init -rnn_gpu.layers[1].state - -θ[1] -θ[2] -θ[3] -θ[4] - -###################### -# basic test -###################### -seq = [[1,2] ./ 10 for b in 1:3] -seq = hcat(seq...) -seq = [seq for i = 1:4] - -m = RNN(2, 5) -m.cell.Wi .= [0.1 0] -m.cell.Wh .= [0.2] -m.cell.b .= 1.1 -# m.cell.h .= 1.2 -m.init .= 1.3 # init value stays at init value - rather than taking state value -m.state .= 1.4 - -params(m) -params(m)[1] -m(seq[2]) -@time m.(seq) -@time map(m, seq) - -###################### -# single array -###################### -seq = [[1,2] ./ 10 for b in 1:3] -seq = hcat(seq...) -seq = [seq for i = 1:4] -seq = cat(seq..., dims=3) - -m = RNN(2, 5) -m.cell.Wi .= [0.1 0] -m.cell.Wh .= [0.1] -m.cell.b .= 0 -# m.cell.h .= 0 -m.init .= 0.0 -m.state .= 0 - -params(m) -@time mapslices(m, seq, dims=(1,2)) -mapslices(size, seq, dims=(1,2)) - - - -###################### -# issue: https://github.com/FluxML/Flux.jl/issues/1114 -###################### -rnn = Chain(LSTM(16, 8), - Dense(8,1, σ), - x -> reshape(x,:)) - -X = [rand(16,10) for i in 1:20] -Y = rand(10,20) ./ 10 - -rnn = rnn |> gpu -X = gpu(X) -Y = gpu(Y) - -θ = Flux.params(rnn) -loss(x,y) = mean((Flux.stack(rnn.(x),2) .- y) .^ 2f0) -opt = ADAM(1e-3) -size(rnn[1].state[1]) -Flux.reset!(rnn) -size(rnn[1].state[1]) -Flux.train!(loss, θ, [(X,Y)], opt) -size(rnn[1].state[1]) -loss(X,Y) - -Flux.stack(rnn.(X),2) -rnn.(X) - -######################## -# rnn test gpu -######################## -feat = 2 -h_size = 11 -seq_len = 4 -batch_size = 3 -rnn = Chain(RNN(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x,:)) - -X = [rand(feat, batch_size) for i in 1:seq_len] -Y = rand(batch_size, seq_len) ./ 10 - -rnn = rnn |> gpu -X = gpu(X) -Y = gpu(Y) - -θ = Flux.params(rnn) -mapreduce(length, +, θ) - h_size -1 # num params in RNN - -function loss(x,y) - l = mean((Flux.stack(map(rnn, x),2) .- y) .^ 2f0) - Flux.reset!(rnn) - return l -end - -opt = ADAM(1e-3) -loss(X,Y) -Flux.reset!(rnn) -Flux.train!(loss, θ, [(X,Y)], opt) -loss(X,Y) -for i in 1:100 - Flux.train!(loss, θ, [(X,Y)], opt) -end -Flux.reset!(rnn) -Flux.train!(loss, θ, [(X,Y)], opt) - -θ[1] -θ[3] -θ[4] From 68abdd889c3b254146783454d4e0f861e0ceb4bb Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Sat, 31 Oct 2020 00:36:00 -0400 Subject: [PATCH 16/17] remove experiment --- test/rnn-test-jdb.jl | 57 -------------------------------------------- 1 file changed, 57 deletions(-) delete mode 100644 test/rnn-test-jdb.jl diff --git a/test/rnn-test-jdb.jl b/test/rnn-test-jdb.jl deleted file mode 100644 index 07d04aa186..0000000000 --- a/test/rnn-test-jdb.jl +++ /dev/null @@ -1,57 +0,0 @@ -using Revise -using Flux -using Statistics: mean - -# illustrate diverging behavior of GPU execution -feat = 32 -h_size = 64 -seq_len = 20 -batch_size = 100 - -rnn = Chain(RNN(feat, h_size), - Dense(h_size, 1, σ), - x -> reshape(x, :)) - -X = [rand(Float32, feat, batch_size) for i in 1:seq_len] -Y = rand(Float32, batch_size, seq_len) ./ 10 - -#### transfer to gpu #### -rnn_gpu = rnn |> gpu -X_gpu = gpu(X) -Y_gpu = gpu(Y) - -θ = Flux.params(rnn) -θ_gpu = Flux.params(rnn_gpu) -function loss(x, y) - l = mean((Flux.stack(map(rnn, x), 2) .- y).^2f0) - Flux.reset!(rnn) - return l -end -function loss_gpu(x, y) - l = mean((Flux.stack(map(rnn_gpu, x), 2) .- y).^2f0) - Flux.reset!(rnn_gpu) - return l -end - -opt = ADAM(1e-3) -opt_gpu = ADAM(1e-3) -for i in 
1:25 - println("iter: ", i) - Flux.train!(loss, θ, [(X, Y)], opt) - Flux.train!(loss_gpu, θ_gpu, [(X_gpu, Y_gpu)], opt_gpu) - println("loss_cpu: ", loss(X, Y)) - println("loss_gpu: ", loss_gpu(X_gpu, Y_gpu)) - println("θ[3][1:2]: ", θ[3][1:2]) - println("θ_gpu[3][1:2]: ", θ_gpu[3][1:2]) - println("θ[4][1:2]: ", θ[4][1:2]) - println("θ_gpu[4][1:2]: ", θ_gpu[4][1:2]) - println("rnn.layers[1].state[1:2]: ", rnn.layers[1].state[1:2]) - println("rnn_gpu.layers[1].state[1:2]: ", rnn_gpu.layers[1].state[1:2]) -end - -θ[1] -θ[2] -θ[3] -θ[4] -θ[5] -θ[6] \ No newline at end of file From 40bf4b7ceb9bf881251b9a0dce60805496310dad Mon Sep 17 00:00:00 2001 From: "jeremie.db" Date: Sat, 31 Oct 2020 01:07:12 -0400 Subject: [PATCH 17/17] fix tests --- test/cuda/curnn.jl | 8 +++----- test/utils.jl | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl index 345401e0ff..a7e5fa234f 100644 --- a/test/cuda/curnn.jl +++ b/test/cuda/curnn.jl @@ -8,7 +8,7 @@ using Flux: pullback Flux.reset!(m) θ = gradient(() -> sum(m(x)), params(m)) @test x isa CuArray - @test_broken θ[m.cell.Wi] isa CuArray + @test θ[m.cell.Wi] isa CuArray @test_broken collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi]) end @@ -20,8 +20,8 @@ end Flux.reset!(rnn) Flux.reset!(curnn) x = batch_size == 1 ? - rand(10) : - rand(10, batch_size) + rand(Float32, 10) : + rand(Float32, 10, batch_size) cux = gpu(x) y, back = pullback((r, x) -> r(x), rnn, x) @@ -29,8 +29,6 @@ end @test y ≈ collect(cuy) - @test haskey(Flux.CUDAint.descs, curnn.cell) - ȳ = randn(size(y)) m̄, x̄ = back(ȳ) cum̄, cux̄ = cuback(gpu(ȳ)) diff --git a/test/utils.jl b/test/utils.jl index b2deed4adb..cb5c12cc2f 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -101,16 +101,16 @@ end m = Dense(10, 5) @test size.(params(m)) == [(5, 10), (5,)] m = RNN(10, 5) - @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)] + @test size.(params(m)) == [(5, 10), (5, 5), (5,), (5, 1)] # Layer duplicated in same chain, params just once pls. c = Chain(m, m) - @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)] + @test size.(params(c)) == [(5, 10), (5, 5), (5,), (5, 1)] # Self-referential array. Just want params, no stack overflow pls. r = Any[nothing,m] r[1] = r - @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)] + @test size.(params(r)) == [(5, 10), (5, 5), (5,), (5, 1)] end @testset "Basic Stacking" begin