diff --git a/src/Flux.jl b/src/Flux.jl
index 2804803947..b0937a9ff3 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -71,7 +71,7 @@ include("functor.jl")
 @compat(public, (
   # from OneHotArrays.jl
-  onehot, onehotbatch, onecold, 
+  onehot, onehotbatch, onecold,
   # from Functors.jl
   functor, @functor, KeyPath, haskeypath, getkeypath,
   # from Optimise/Train/Optimisers.jl
@@ -90,6 +90,7 @@ include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 include("layers/upsample.jl")
+include("layers/rotary.jl")
 include("layers/attention.jl")
 
 include("loading.jl")
diff --git a/src/layers/attention.jl b/src/layers/attention.jl
index d4a33283d9..1ce78bc1d2 100644
--- a/src/layers/attention.jl
+++ b/src/layers/attention.jl
@@ -14,7 +14,7 @@ Returns the transformed input sequence and the attention scores.
 # Arguments
 
 - `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
-  In the most general case, it is given as 
+  In the most general case, it is given as
   a) `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
   Can take also simpler forms as
   b) `dims::Int`;
@@ -26,22 +26,24 @@ Returns the transformed input sequence and the attention scores.
 - `dropout_prob`: dropout probability for the attention scores. Default `0.0`.
 
 # Forward
-
-    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask])
+
+    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask], [rope])
 
 The arguments of the forward pass are:
 
 - `q_in`: Input query array of size `(q_in_dim, q_len, batch_size)`.
 - `k_in`: Input key array of size `(k_in_dim, kv_len, batch_size)`.
 - `v_in`: Input value array of size `(v_in_dim, kv_len, batch_size)`.
-- `bias`: Bias array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. 
+- `bias`: Bias array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
   It will be added to the attention scores before the softmax.
   Default `nothing`.
-- `mask`: Input array broadcastable to size 
-  `(kv_len, q_len, nheads, batch_size)`. 
-  The mask is applied to the attention scores just before the softmax. 
-  See [`NNlib.make_causal_mask`](@ref) for creating causal masks. 
+- `mask`: Input array broadcastable to size
+  `(kv_len, q_len, nheads, batch_size)`.
+  The mask is applied to the attention scores just before the softmax.
+  See [`NNlib.make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
+- `rope`: Whether to apply rotary position embeddings (RoPE) to the projected
+  queries and keys. Default `false`.
 
 Alternative calling signatures are `mha(q_in)`, equivalent to
 `mha(q_in, q_in, q_in)` (self-attention), and `mha(q_in, k_in)`, equivalent to
 `mha(q_in, k_in, k_in)` (key and value are the same).
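As a quick illustration of the new keyword (not part of the patch), a forward call with `rope = true` should keep the same output shapes as the existing docstring example; something along these lines:

```julia
using Flux

mha = MultiHeadAttention(64, nheads = 8)
q = rand(Float32, 64, 10, 32)
k = rand(Float32, 64, 20, 32)
v = rand(Float32, 64, 20, 32)

# RoPE is applied to the projected queries and keys; output sizes are unchanged.
y, α = mha(q, k, v; rope = true)
@assert size(y) == (64, 10, 32)
@assert size(α) == (20, 10, 8, 32)
```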
@@ -55,7 +57,7 @@ mha = MultiHeadAttention(64, nheads = 8)
 q = rand(Float32, (64, 10, 32))
 k = rand(Float32, (64, 20, 32))
 v = rand(Float32, (64, 20, 32))
-y, α = mha(q, k, v) 
+y, α = mha(q, k, v)
 # [y] = [64, 10, 32]
 # [α] = [20, 10, 8, 32]
@@ -76,10 +78,10 @@ end
 
 @layer MultiHeadAttention
 
-function MultiHeadAttention(dims; 
+function MultiHeadAttention(dims;
     nheads::Int = 8,
     bias::Bool = false,
-    init = glorot_uniform, 
+    init = glorot_uniform,
     dropout_prob = 0.0)
 
   dims = normalize_mha_dims(dims)
@@ -94,7 +96,7 @@ function MultiHeadAttention(dims;
 end
 
 # turns the dims argument into a named tuple
-normalize_mha_dims(dims::Int) = 
+normalize_mha_dims(dims::Int) =
   (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)
 
 function normalize_mha_dims((in, (qkv, out))::Pair{<:IntOrDims{3}, <:Pair{<:IntOrDims{2}, Int}})
@@ -117,14 +119,18 @@ end
 # key and value are the same
 (mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)
 
-function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, 
-    bias=nothing; mask=nothing)
+function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
+    bias=nothing; mask=nothing, rope=false)
   ## [q_in] = [q_in_dim, q_len, batch_size]
-  ## [k_in] = [k_in_dim, kv_len, batch_size] 
+  ## [k_in] = [k_in_dim, kv_len, batch_size]
   ## [v_in] = [v_in_dim, kv_len, batch_size]
   q = mha.q_proj(q_in)      # [q] = [qk_dim, q_len, batch_size]
-  k = mha.k_proj(k_in)      # [k] = [qk_dim, kv_len, batch_size] 
+  k = mha.k_proj(k_in)      # [k] = [qk_dim, kv_len, batch_size]
   v = mha.v_proj(v_in)      # [v] = [v_dim, kv_len, batch_size]
+  if rope
+    q = with_rotary_position_embedding(q)
+    k = with_rotary_position_embedding(k)
+  end
   x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
   x = mha.out_proj(x)       # [x] = [out_dim, q_len, batch_size]
@@ -158,7 +164,6 @@ function Base.show(io::IO, mha::MultiHeadAttention)
   print(io, ")")
 end
 
-
 #=
 
 # Test cases for printing:
diff --git a/src/layers/rotary.jl b/src/layers/rotary.jl
new file mode 100644
index 0000000000..8c90bf68a0
--- /dev/null
+++ b/src/layers/rotary.jl
@@ -0,0 +1,157 @@
+"""
+    Rotary Position Embeddings (RoPE)
+
+This is a simplified port of the RoPE implementation from NeuralAttentionlib.jl,
+which implements the Rotary Position Embeddings (RoPE) described in the RoFormer paper.
+
+Original sources:
+- Paper: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
+  Authors: Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen
+  URL: https://arxiv.org/abs/2104.09864
+
+- Code: NeuralAttentionlib.jl
+  Author: chengchingwen
+  Repository: https://github.com/chengchingwen/NeuralAttentionlib.jl
+
+RoPE encodes absolute positional information with a rotation matrix that naturally
+incorporates explicit relative position dependency in the self-attention formulation.
+"""
+
+"""
+Frequency of the `i`-th feature pair, `θ_i = 10000^(-2(i - 1) / hidden_size)`,
+as in the RoFormer paper.
+"""
+function _default_freq_decay(i, hidden_size)
+    j = 8 * (1 - i)
+    return Float32(10^(j / hidden_size))  # 10^(8(1 - i)/d) == 10000^(-2(i - 1)/d)
+end
+
+"""
+Sinusoidal position embedding for feature `idx` at position `pos`:
+`sin((pos - 1) * θ_i)` for odd features and `cos((pos - 1) * θ_i)` for even features.
+"""
+function sincos_position_embed(pos, idx, hidden_size)
+    feature = Int32(idx)
+    i = (feature + 1) >> 1  # pair index: features 2i-1 and 2i share the frequency θ_i
+    pos_idx = Int32(pos - 1)
+
+    freq = _default_freq_decay(i, hidden_size)
+    angle = pos_idx * freq
+
+    return iseven(feature) ? cos(angle) : sin(angle)
+end
+
+ChainRulesCore.@non_differentiable sincos_position_embed(pos, idx, hidden_size)
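A worked check of the frequency and phase convention above (illustration only, not part of the patch); `sincos_position_embed` is an internal, unexported helper, so it is accessed by name here:

```julia
using Flux: sincos_position_embed   # internal helper, not exported

d = 8                                  # feature (hidden) dimension
pos, i = 5, 2                          # position 5, second feature pair
θ = Float32(10_000^(-2(i - 1) / d))    # RoFormer frequency for pair i

# Odd features carry sin, even features carry cos, with a 0-based position index.
@assert sincos_position_embed(pos, 2i - 1, d) ≈ sin((pos - 1) * θ)
@assert sincos_position_embed(pos, 2i, d) ≈ cos((pos - 1) * θ)
```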
+""" +function _rotary((x1, x2), (sin_θ, cos_θ)) + return ( + x1 * cos_θ - x2 * sin_θ, + x2 * cos_θ + x1 * sin_θ + ) +end + +""" +Apply rotary embeddings to the full tensor. +""" +function _apply_rotary(x, seq_len) + hidden_size = size(x, 1) + + # Get positional encodings + pos_enc = similar(x, hidden_size, seq_len) + for i in 1:hidden_size, j in 1:seq_len + pos_enc[i,j] = sincos_position_embed(j, i, hidden_size) + end + + # Reshape to handle pairs properly + x_reshaped = reshape(x, 2, :) + pos_reshaped = reshape(pos_enc, 2, :) + + # Now reinterpret as pairs + x_pairs = reinterpret(reshape, NTuple{2,eltype(x)}, x_reshaped) + pos_pairs = reinterpret(reshape, NTuple{2,eltype(pos_enc)}, pos_reshaped) + + # Apply rotary transformation + y_reshaped = similar(x_reshaped) + y_pairs = reinterpret(reshape, NTuple{2,eltype(y_reshaped)}, y_reshaped) + + for i in axes(x_pairs, 1) + y_pairs[i] = _rotary(x_pairs[i], pos_pairs[i]) + end + + # Reshape back to original dimensions + return reshape(y_reshaped, size(x)) +end + +""" +Apply rotary position embeddings to input tensor x. +""" +function with_rotary_position_embedding(x::AbstractArray) + hidden_size = size(x, 1) + iseven(hidden_size) || throw(ArgumentError("Feature dimension ($(hidden_size)) must be even")) + return _apply_rotary(x, size(x, 2)) +end + +# Gradient rules +function ChainRulesCore.rrule(::typeof(_rotary), x_pair, pos_pair) + x1, x2 = x_pair + sin_θ, cos_θ = pos_pair + y1, y2 = _rotary(x_pair, pos_pair) + + function rotary_pullback(Ȳ) + ∂y1, ∂y2 = Ȳ + + ∂x1 = ∂y1 * cos_θ + ∂y2 * sin_θ + ∂x2 = -∂y1 * sin_θ + ∂y2 * cos_θ + + return (NoTangent(), (∂x1, ∂x2), NoTangent()) + end + + return (y1, y2), rotary_pullback +end + +function ChainRulesCore.rrule(::typeof(_apply_rotary), x, seq_len) + y = _apply_rotary(x, seq_len) + + function apply_rotary_pullback(Ȳ) + hidden_size = size(x, 1) + + # Recalculate position encodings for gradient + pos_enc = similar(x, hidden_size, seq_len) + for i in 1:hidden_size, j in 1:seq_len + pos_enc[i,j] = sincos_position_embed(j, i, hidden_size) + end + + # Reshape for gradient computation + x_reshaped = reshape(x, 2, :) + pos_reshaped = reshape(pos_enc, 2, :) + Ȳ_reshaped = reshape(Ȳ, 2, :) + + x_pairs = reinterpret(reshape, NTuple{2,eltype(x)}, x_reshaped) + pos_pairs = reinterpret(reshape, NTuple{2,eltype(pos_enc)}, pos_reshaped) + Ȳ_pairs = reinterpret(reshape, NTuple{2,eltype(Ȳ)}, Ȳ_reshaped) + + ∂x_reshaped = similar(x_reshaped) + ∂x_pairs = reinterpret(reshape, NTuple{2,eltype(∂x_reshaped)}, ∂x_reshaped) + + for i in axes(x_pairs, 1) + _, pb = rrule(_rotary, x_pairs[i], pos_pairs[i]) + ∂x_pairs[i] = pb(Ȳ_pairs[i])[2] + end + + return (NoTangent(), reshape(∂x_reshaped, size(x)), NoTangent()) + end + + return y, apply_rotary_pullback +end + +function ChainRulesCore.rrule(::typeof(with_rotary_position_embedding), x::AbstractArray) + y = with_rotary_position_embedding(x) + + function rotary_pullback(Ȳ) + _, ∂x, _ = rrule(_apply_rotary, x, size(x,2))[2](Ȳ) + return (NoTangent(), ∂x) + end + + return y, rotary_pullback +end diff --git a/src/utils.jl b/src/utils.jl index 20d16596ed..302921102d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -205,7 +205,7 @@ ChainRulesCore.@non_differentiable kaiming_normal(::Any...) """ truncated_normal([rng], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array truncated_normal([rng]; kw...) -> Function - + Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution. 
diff --git a/src/utils.jl b/src/utils.jl
index 20d16596ed..302921102d 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -205,7 +205,7 @@ ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
 """
     truncated_normal([rng], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
     truncated_normal([rng]; kw...) -> Function
-    
+
 Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
 The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`.
@@ -253,7 +253,7 @@ ChainRulesCore.@non_differentiable truncated_normal(::Any...)
     lecun_normal([rng]; kw...) -> Function
 
 Return an `Array{Float32}` of the given `size` containing random numbers drawn from a truncated normal
-distribution centered on 0 with stddev `sqrt(1 / fan_in)`, where `fan_in` is the number of input units 
+distribution centered on 0 with stddev `sqrt(1 / fan_in)`, where `fan_in` is the number of input units
 in the weight tensor.
 
 # Examples
@@ -414,7 +414,7 @@ Has the following behaviour
 * 2D: An identity matrix (useful for an identity matrix multiplication)
 * More than 2D: A dense block array of center tap spatial filters (useful for an identity convolution)
 
-Some caveats: 
+Some caveats:
 * Not all layers will be identity mapping when used with this init.
   Exceptions include recurrent layers and normalization layers.
diff --git a/test/layers/rotary.jl b/test/layers/rotary.jl
new file mode 100644
index 0000000000..8fcb98105a
--- /dev/null
+++ b/test/layers/rotary.jl
@@ -0,0 +1,32 @@
+using Flux: with_rotary_position_embedding
+
+@testset "Rotary Position Embedding Tests" begin
+    Random.seed!(123)
+    test_sizes = [(2,2), (4,6), (8,10)]
+
+    for (n, d) in test_sizes
+        x = randn(n, d)
+        test_gradients(
+            with_rotary_position_embedding,
+            x;
+            rtol=1e-4,
+            atol=1e-4,
+            test_gpu=false,
+            compare_finite_diff=true,
+            loss=(f, x) -> sum(f(x))
+        )
+    end
+
+    # Edge cases
+    test_gradients(
+        with_rotary_position_embedding,
+        zeros(4, 6);
+        loss=(f, x) -> sum(f(x))
+    )
+
+    test_gradients(
+        with_rotary_position_embedding,
+        ones(4, 6);
+        loss=(f, x) -> sum(f(x))
+    )
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 6f5a2e7d84..983ee67eab 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -55,6 +55,7 @@ Random.seed!(0)
     include("layers/upsample.jl")
     include("layers/show.jl")
     include("layers/macro.jl")
+    include("layers/rotary.jl")
   end
 
 @testset "outputsize" begin
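Beyond the finite-difference tests above, a quick end-to-end gradient smoke test one might run by hand (illustration only; it assumes the angle indexing in `_apply_rotary` covers the batched 3-d inputs used by `MultiHeadAttention`):

```julia
using Flux
using Flux: with_rotary_position_embedding

# Gradient through the standalone function
x = randn(Float32, 8, 10)
gx = Flux.gradient(x -> sum(with_rotary_position_embedding(x)), x)[1]
@assert size(gx) == size(x)

# Gradient through MultiHeadAttention with rope = true on a batched input
mha = MultiHeadAttention(16, nheads = 2)
q = randn(Float32, 16, 5, 3)
gm = Flux.gradient(m -> sum(first(m(q, q, q; rope = true))), mha)[1]
@assert gm !== nothing
```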