From 51ca97fd2274b48aa8c7f8dd4bdc6de9b0c71921 Mon Sep 17 00:00:00 2001
From: Francesco Martinuzzi
Date: Sun, 10 Nov 2024 21:03:27 +0100
Subject: [PATCH] Distinct init for kernel and recurrent (#2522)

---
 src/layers/recurrent.jl | 176 +++++++++++++++++++++++++---------------
 1 file changed, 112 insertions(+), 64 deletions(-)

diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 7c69b80103..0834ce5b3f 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -2,7 +2,8 @@
 # Vanilla RNN
 
 @doc raw"""
-    RNNCell(in => out, σ = tanh; init = glorot_uniform, bias = true)
+    RNNCell(in => out, σ = tanh; init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
 
 The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
@@ -19,7 +20,8 @@ See [`RNN`](@ref) for a layer that processes entire sequences.
 
 - `in => out`: The input and output dimensions of the layer.
 - `σ`: The non-linearity to apply to the output. Default is `tanh`.
-- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
 
 # Forward
@@ -58,18 +60,24 @@ h   # The final hidden state
 ŷ   # The hidden states at each time step
 ```
 """
-struct RNNCell{F,I,H,V}
+struct RNNCell{F, I, H, V}
   σ::F
   Wi::I
   Wh::H
   bias::V
 end
 
-@layer RNNCell 
+@layer RNNCell
 
-function RNNCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
-  Wi = init(out, in)
-  Wh = init(out, out)
+function RNNCell(
+    (in, out)::Pair,
+    σ = tanh;
+    init_kernel = glorot_uniform,
+    init_recurrent_kernel = glorot_uniform,
+    bias = true,
+)
+  Wi = init_kernel(out, in)
+  Wh = init_recurrent_kernel(out, out)
   b = create_bias(Wi, bias, size(Wi, 1))
   return RNNCell(σ, Wi, Wh, b)
 end
@@ -77,9 +85,9 @@ end
 (m::RNNCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 1)))
 
 function (m::RNNCell)(x::AbstractVecOrMat, h::AbstractVecOrMat)
-  _size_check(m, x, 1 => size(m.Wi,2))
+  _size_check(m, x, 1 => size(m.Wi, 2))
   σ = NNlib.fast_act(m.σ, x)
-  h = σ.(m.Wi*x .+ m.Wh*h .+ m.bias)
+  h = σ.(m.Wi * x .+ m.Wh * h .+ m.bias)
   return h
 end
 
@@ -90,7 +98,8 @@ function Base.show(io::IO, m::RNNCell)
 end
 
 @doc raw"""
-    RNN(in => out, σ = tanh; bias = true, init = glorot_uniform)
+    RNN(in => out, σ = tanh; init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
 
 The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the
 output fed back into the input each time step.
@@ -108,7 +117,8 @@ See [`RNNCell`](@ref) for a layer that processes a single time step.
 
 - `in => out`: The input and output dimensions of the layer.
 - `σ`: The non-linearity to apply to the output. Default is `tanh`.
-- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
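+
+For example, the input and recurrent kernels can be initialized independently
+(a usage sketch; `kaiming_uniform` and `orthogonal` are initializers exported
+by Flux, and any function of the form `init(dims...)` can be substituted):
+
+```julia
+rnn = RNN(3 => 5, tanh; init_kernel = kaiming_uniform,
+          init_recurrent_kernel = orthogonal)
+```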
 
 # Forward
@@ -161,30 +171,31 @@ end
 
 @layer :expand RNN
 
-function RNN((in, out)::Pair, σ = tanh; bias = true, init = glorot_uniform)
-  cell = RNNCell(in => out, σ; bias, init)
+function RNN((in, out)::Pair, σ = tanh; cell_kwargs...)
+  cell = RNNCell(in => out, σ; cell_kwargs...)
   return RNN(cell)
 end
 
 (m::RNN)(x) = m(x, zeros_like(x, size(m.cell.Wh, 1)))
 
-function (m::RNN)(x, h) 
+function (m::RNN)(x, h)
   @assert ndims(x) == 2 || ndims(x) == 3
   # [x] = [in, L] or [in, L, B]
   # [h] = [out] or [out, B]
   y = []
-  for x_t in eachslice(x, dims=2)
+  for x_t in eachslice(x, dims = 2)
     h = m.cell(x_t, h)
     # y = [y..., h]
    y = vcat(y, [h])
   end
-  return stack(y, dims=2)
+  return stack(y, dims = 2)
 end
 
 # LSTM
 @doc raw"""
-    LSTMCell(in => out; init = glorot_uniform, bias = true)
+    LSTMCell(in => out; init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
 
 The [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) cell.
 Behaves like an RNN but generally exhibits a longer memory span over sequences.
 
 See also [`LSTM`](@ref) for a layer that processes entire sequences.
@@ -205,7 +216,8 @@ See also [`LSTM`](@ref) for a layer that processes entire sequences.
 
 # Arguments
 
 - `in => out`: The input and output dimensions of the layer.
-- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
 
 # Forward
@@ -239,7 +251,7 @@ julia> size(h′)  # out x batch_size
 (5, 4)
 ```
 """
-struct LSTMCell{I,H,V}
+struct LSTMCell{I, H, V}
   Wi::I
   Wh::H
   bias::V
@@ -247,9 +259,14 @@ end
 
 @layer LSTMCell
 
-function LSTMCell((in, out)::Pair; init = glorot_uniform, bias = true)
-  Wi = init(out * 4, in)
-  Wh = init(out * 4, out)
+function LSTMCell(
+    (in, out)::Pair;
+    init_kernel = glorot_uniform,
+    init_recurrent_kernel = glorot_uniform,
+    bias = true,
+)
+  Wi = init_kernel(out * 4, in)
+  Wh = init_recurrent_kernel(out * 4, out)
   b = create_bias(Wi, bias, out * 4)
   cell = LSTMCell(Wi, Wh, b)
   return cell
 end
@@ -265,18 +282,19 @@ function (m::LSTMCell)(x::AbstractVecOrMat, (h, c))
   _size_check(m, x, 1 => size(m.Wi, 2))
   b = m.bias
   g = m.Wi * x .+ m.Wh * h .+ b
-  input, forget, cell, output = chunk(g, 4; dims=1)
+  input, forget, cell, output = chunk(g, 4; dims = 1)
   c′ = @. sigmoid_fast(forget) * c + sigmoid_fast(input) * tanh_fast(cell)
   h′ = @. sigmoid_fast(output) * tanh_fast(c′)
   return h′, c′
 end
 
 Base.show(io::IO, m::LSTMCell) =
-  print(io, "LSTMCell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷4, ")")
+  print(io, "LSTMCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 4, ")")
 
 @doc raw"""
-    LSTM(in => out; init = glorot_uniform, bias = true)
+    LSTM(in => out; init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
 
 [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory)
 recurrent layer. Behaves like an RNN but generally exhibits a longer memory span over sequences.
@@ -299,7 +317,8 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step.
 
 # Arguments
 
 - `in => out`: The input and output dimensions of the layer.
-- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
 
 # Forward
@@ -342,8 +361,8 @@ end
 
 @layer :expand LSTM
 
-function LSTM((in, out)::Pair; init = glorot_uniform, bias = true)
-  cell = LSTMCell(in => out; init, bias)
+function LSTM((in, out)::Pair; cell_kwargs...)
+  cell = LSTMCell(in => out; cell_kwargs...)
   return LSTM(cell)
 end
@@ -357,18 +376,19 @@ function (m::LSTM)(x, (h, c))
   @assert ndims(x) == 2 || ndims(x) == 3
   h′ = []
   c′ = []
-  for x_t in eachslice(x, dims=2)
+  for x_t in eachslice(x, dims = 2)
     h, c = m.cell(x_t, (h, c))
     h′ = vcat(h′, [h])
     c′ = vcat(c′, [c])
   end
-  return stack(h′, dims=2), stack(c′, dims=2)
+  return stack(h′, dims = 2), stack(c′, dims = 2)
 end
 
 # GRU
 @doc raw"""
-    GRUCell(in => out; init = glorot_uniform, bias = true)
+    GRUCell(in => out; init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer.
 Behaves like an RNN but generally exhibits a longer memory span over sequences.
 
 See also [`GRU`](@ref) for a layer that processes entire sequences.
@@ -388,7 +408,8 @@ See also [`GRU`](@ref) for a layer that processes entire sequences.
 
 # Arguments
 
 - `in => out`: The input and output dimensions of the layer.
-- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
 
 # Forward
@@ -416,7 +437,7 @@ julia> x = rand(Float32, 3, 4);  # in x batch_size
 
 julia> h′ = g(x, h);
 ```
 """
-struct GRUCell{I,H,V}
+struct GRUCell{I, H, V}
   Wi::I
   Wh::H
   b::V
@@ -424,9 +445,14 @@ end
 
 @layer GRUCell
 
-function GRUCell((in, out)::Pair; init = glorot_uniform, bias = true)
-  Wi = init(out * 3, in)
-  Wh = init(out * 3, out)
+function GRUCell(
+    (in, out)::Pair;
+    init_kernel = glorot_uniform,
+    init_recurrent_kernel = glorot_uniform,
+    bias = true,
+)
+  Wi = init_kernel(out * 3, in)
+  Wh = init_recurrent_kernel(out * 3, out)
   b = create_bias(Wi, bias, size(Wi, 1))
   return GRUCell(Wi, Wh, b)
 end
@@ -434,11 +460,11 @@ end
 
 (m::GRUCell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2)))
 
 function (m::GRUCell)(x::AbstractVecOrMat, h)
-  _size_check(m, x, 1 => size(m.Wi,2))
-  gxs = chunk(m.Wi * x, 3, dims=1)
-  ghs = chunk(m.Wh * h, 3, dims=1)
+  _size_check(m, x, 1 => size(m.Wi, 2))
+  gxs = chunk(m.Wi * x, 3, dims = 1)
+  ghs = chunk(m.Wh * h, 3, dims = 1)
   if m.b isa AbstractArray
-    bs = chunk(m.b, 3, dims=1)
+    bs = chunk(m.b, 3, dims = 1)
   else # b == false
     bs = [false, false, false]
   end
@@ -450,10 +476,11 @@ function (m::GRUCell)(x::AbstractVecOrMat, h)
 end
 
 Base.show(io::IO, m::GRUCell) =
-  print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷3, ")")
+  print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")")
 
 @doc raw"""
-    GRU(in => out; init = glorot_uniform, bias = true)
+    GRU(in => out; init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
 RNN but generally exhibits a longer memory span over sequences.
 This implements
@@ -470,6 +497,13 @@ h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1}
 
 for all `len` steps `t` in the input sequence.
 See [`GRUCell`](@ref) for a layer that processes a single time step.
 
+# Arguments
+
+- `in => out`: The input and output dimensions of the layer.
+- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
+- `bias`: Whether to include a bias term initialized to zero. Default is `true`.
+
 # Forward
 
     gru(x, h)
@@ -499,8 +533,8 @@ end
 
 @layer :expand GRU
 
-function GRU((in, out)::Pair; init = glorot_uniform, bias = true)
-  cell = GRUCell(in => out; init, bias)
+function GRU((in, out)::Pair; cell_kwargs...)
+  cell = GRUCell(in => out; cell_kwargs...)
   return GRU(cell)
 end
@@ -513,16 +547,17 @@ function (m::GRU)(x, h)
   @assert ndims(x) == 2 || ndims(x) == 3
   h′ = []
   # [x] = [in, L] or [in, L, B]
-  for x_t in eachslice(x, dims=2)
+  for x_t in eachslice(x, dims = 2)
     h = m.cell(x_t, h)
     h′ = vcat(h′, [h])
   end
-  return stack(h′, dims=2)
+  return stack(h′, dims = 2)
 end
 
 # GRU v3
 @doc raw"""
-    GRUv3Cell(in => out, init = glorot_uniform, bias = true)
+    GRUv3Cell(in => out; init_kernel = glorot_uniform,
+        init_recurrent_kernel = glorot_uniform, bias = true)
 
 [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer.
 Behaves like an RNN but generally exhibits a longer memory span over sequences.
@@ -543,7 +578,8 @@ See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer.
 
 # Arguments
 
 - `in => out`: The input and output dimensions of the layer.
-- `init`: The initialization function to use for the weights. Default is `glorot_uniform`.
+- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`.
+- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`.
 - `bias`: Whether to include a bias term initialized to zero. Default is `true`.
 
 # Forward
 
     gruv3cell(x, h)
 
 The arguments of the forward pass are:
 
 - `x`: The input to the GRU. It should be a vector of size `in` or a matrix of size `in x batch_size`.
 - `h`: The hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`.
 
 Returns the new hidden state `h'` as an array of size `out` or `out x batch_size`.
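+
+For example, the two kernels can be initialized independently (a usage sketch;
+`kaiming_uniform` and `orthogonal` are initializers exported by Flux, and any
+function of the form `init(dims...)` can be substituted):
+
+```julia
+g = GRUv3Cell(3 => 5; init_kernel = kaiming_uniform,
+              init_recurrent_kernel = orthogonal)
+h′ = g(rand(Float32, 3), zeros(Float32, 5))  # new hidden state of size 5
+```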
""" -struct GRUv3Cell{I,H,V,HH} +struct GRUv3Cell{I, H, V, HH} Wi::I Wh::H b::V @@ -567,10 +603,15 @@ end @layer GRUv3Cell -function GRUv3Cell((in, out)::Pair; init = glorot_uniform, bias = true) - Wi = init(out * 3, in) - Wh = init(out * 3, out) - Wh_h̃ = init(out, out) +function GRUv3Cell( + (in, out)::Pair; + init_kernel = glorot_uniform, + init_recurrent_kernel = glorot_uniform, + bias = true, +) + Wi = init_kernel(out * 3, in) + Wh = init_recurrent_kernel(out * 3, out) + Wh_h̃ = init_recurrent_kernel(out, out) b = create_bias(Wi, bias, out * 3) return GRUv3Cell(Wi, Wh, b, Wh_h̃) end @@ -578,11 +619,11 @@ end (m::GRUv3Cell)(x::AbstractVecOrMat) = m(x, zeros_like(x, size(m.Wh, 2))) function (m::GRUv3Cell)(x::AbstractVecOrMat, h) - _size_check(m, x, 1 => size(m.Wi,2)) - gxs = chunk(m.Wi * x, 3, dims=1) - ghs = chunk(m.Wh * h, 3, dims=1) + _size_check(m, x, 1 => size(m.Wi, 2)) + gxs = chunk(m.Wi * x, 3, dims = 1) + ghs = chunk(m.Wh * h, 3, dims = 1) if m.b isa AbstractArray - bs = chunk(m.b, 3, dims=1) + bs = chunk(m.b, 3, dims = 1) else # m.b == false bs = [false, false, false] end @@ -594,11 +635,12 @@ function (m::GRUv3Cell)(x::AbstractVecOrMat, h) end Base.show(io::IO, m::GRUv3Cell) = - print(io, "GRUv3Cell(", size(m.Wi, 2), " => ", size(m.Wi, 1)÷3, ")") + print(io, "GRUv3Cell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")") @doc raw""" - GRUv3(in => out) + GRUv3(in => out; init_kernel = glorot_uniform, + init_recurrent_kernel = glorot_uniform, bias = true) [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an RNN but generally exhibits a longer memory span over sequences. This implements @@ -615,6 +657,13 @@ h_t = (1 - z_t) \odot h̃_t + z_t \odot h_{t-1} for all `len` steps `t` in the input sequence. See [`GRUv3Cell`](@ref) for a layer that processes a single time step. See [`GRU`](@ref) and [`GRUCell`](@ref) for variants of this layer. + +# Arguments + +- `in => out`: The input and output dimensions of the layer. +- `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`. +- `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`. +- `bias`: Whether to include a bias term initialized to zero. Default is `true`. """ struct GRUv3{M} cell::M @@ -622,8 +671,8 @@ end @layer :expand GRUv3 -function GRUv3((in, out)::Pair; init = glorot_uniform, bias = true) - cell = GRUv3Cell(in => out; init, bias) +function GRUv3((in, out)::Pair; cell_kwargs...) + cell = GRUv3Cell(in => out; cell_kwargs...) return GRUv3(cell) end @@ -635,10 +684,9 @@ end function (m::GRUv3)(x, h) @assert ndims(x) == 2 || ndims(x) == 3 h′ = [] - for x_t in eachslice(x, dims=2) + for x_t in eachslice(x, dims = 2) h = m.cell(x_t, h) h′ = vcat(h′, [h]) end - return stack(h′, dims=2) + return stack(h′, dims = 2) end -