Port Rotary Positional Embedding from NeuralAttentionlib.jl #2524

Closed · wants to merge 4 commits
3 changes: 2 additions & 1 deletion src/Flux.jl
@@ -71,7 +71,7 @@ include("functor.jl")

@compat(public, (
# from OneHotArrays.jl
onehot, onehotbatch, onecold,
# from Functors.jl
functor, @functor, KeyPath, haskeypath, getkeypath,
# from Optimise/Train/Optimisers.jl
@@ -90,6 +90,7 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/rotary.jl")
include("layers/attention.jl")

include("loading.jl")
39 changes: 22 additions & 17 deletions src/layers/attention.jl
@@ -14,7 +14,7 @@ Returns the transformed input sequence and the attention scores.
# Arguments

- `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
In the most general case, it is given as
a) `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
Can take also simpler forms as
b) `dims::Int`;
@@ -26,22 +26,24 @@ Returns the transformed input sequence and the attention scores.
- `dropout_prob`: dropout probability for the attention scores. Default `0.0`.

# Forward

    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask], [rope])

The arguments of the forward pass are:

- `q_in`: Input query array of size `(q_in_dim, q_len, batch_size)`.
- `k_in`: Input key array of size `(k_in_dim, kv_len, batch_size)`.
- `v_in`: Input value array of size `(v_in_dim, kv_len, batch_size)`.
- `bias`: Bias array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
It will be added to the attention scores before the softmax.
Default `nothing`.
- `mask`: Input array broadcastable to size
  `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`NNlib.make_causal_mask`](@ref) for creating causal masks.
  Default `nothing`.
- `rope`: Whether to apply rotary position embeddings to the input tensors.
Default `false`.

Alternative calling signatures are `mha(q_in)`, equivalent to `mha(q_in, q_in, q_in)` (self-attention),
and `mha(q_in, k_in)`, equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
Expand All @@ -55,7 +57,7 @@ mha = MultiHeadAttention(64, nheads = 8)
q = rand(Float32, (64, 10, 32))
k = rand(Float32, (64, 20, 32))
v = rand(Float32, (64, 20, 32))
y, α = mha(q, k, v)
# [y] = [64, 10, 32]
# [α] = [20, 10, 8, 32]
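
Passing the new `rope` keyword applies rotary position embeddings to the projected queries and keys; the output sizes are unchanged (a usage sketch):

y, α = mha(q, k, v; rope = true)
# [y] = [64, 10, 32]
# [α] = [20, 10, 8, 32]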

Expand All @@ -76,10 +78,10 @@ end

@layer MultiHeadAttention

function MultiHeadAttention(dims;
nheads::Int = 8,
bias::Bool = false,
init = glorot_uniform,
dropout_prob = 0.0)

dims = normalize_mha_dims(dims)
@@ -94,7 +96,7 @@ function MultiHeadAttention(dims;
end

# turns the dims argument into a named tuple
normalize_mha_dims(dims::Int) =
(; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

function normalize_mha_dims((in, (qkv, out))::Pair{<:IntOrDims{3}, <:Pair{<:IntOrDims{2}, Int}})
@@ -117,14 +119,18 @@ end
# key and value are the same
(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)

function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
        bias=nothing; mask=nothing, rope=false)
## [q_in] = [q_in_dim, q_len, batch_size]
## [k_in] = [k_in_dim, kv_len, batch_size]
## [v_in] = [v_in_dim, kv_len, batch_size]
q = mha.q_proj(q_in) # [q] = [qk_dim, q_len, batch_size]
k = mha.k_proj(k_in) # [k] = [qk_dim, kv_len, batch_size]
v = mha.v_proj(v_in) # [v] = [v_dim, kv_len, batch_size]
if rope
q = with_rotary_position_embedding(q)
k = with_rotary_position_embedding(k)
end
x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
x = mha.out_proj(x)
# [x] = [out_dim, q_len, batch_size]
@@ -158,7 +164,6 @@ function Base.show(io::IO, mha::MultiHeadAttention)
print(io, ")")
end


#=

# Test cases for printing:
157 changes: 157 additions & 0 deletions src/layers/rotary.jl
@@ -0,0 +1,157 @@
"""
Rotary Position Embeddings (RoPE)

This is a simplified port of the RoPE implementation from NeuralAttentionlib.jl, which implements
the Rotary Position Embeddings (RoPE) described in the RoFormer paper.

Original sources:
- Paper: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
Authors: Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen
URL: https://arxiv.org/abs/2104.09864

- Code: NeuralAttentionlib.jl
Author: chengchingwen
Repository: https://github.com/chengchingwen/NeuralAttentionlib.jl

RoPE encodes absolute positional information with a rotation matrix that naturally
incorporates explicit relative position dependency in the self-attention formulation.
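
In the paper's notation (a reference formula added for clarity, not part of the ported code), each
consecutive feature pair (x_{2i-1}, x_{2i}) at position m is rotated by a frequency-dependent angle:

    \begin{pmatrix} x'_{2i-1} \\ x'_{2i} \end{pmatrix}
    = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
    \begin{pmatrix} x_{2i-1} \\ x_{2i} \end{pmatrix},
    \qquad \theta_i = 10000^{-2(i-1)/d},

where d is the feature (head) dimension.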
"""

"""
Calculate position-dependent frequency.
"""
function _default_freq_decay(i, hidden_size)
j = 8 * (1 - i)
return Float32(10^(j / hidden_size))
end
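
This is the RoFormer frequency schedule θ_i = 10000^(-2(i-1)/d) rewritten in base 10, since
10^(8(1-i)/d) = (10^4)^(-2(i-1)/d). A quick sanity check of that equivalence (illustrative only,
not part of the port):

let i = 3, d = 64
    _default_freq_decay(i, d) ≈ 10000^(-2 * (i - 1) / d)   # true
end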

"""
Calculate sinusoidal position embedding.
"""
function sincos_position_embed(pos, idx, hidden_size)
feature = Int32(idx)
i = (feature + 1) >> 1 # pair index: features (1, 2) -> 1, (3, 4) -> 2, ...
pos_idx = Int32(pos - 1)

freq = _default_freq_decay(i, hidden_size)
angle = pos_idx * freq

return iseven(feature) ? cos(angle) : sin(angle)
end

ChainRulesCore.@non_differentiable sincos_position_embed(pos, idx, hidden_size)

"""
Apply rotation to a pair of values.
"""
function _rotary((x1, x2), (sin_θ, cos_θ))
return (
x1 * cos_θ - x2 * sin_θ,
x2 * cos_θ + x1 * sin_θ
)
end

"""
Apply rotary embeddings to the full tensor.
"""
function _apply_rotary(x, seq_len)
hidden_size = size(x, 1)

# Get positional encodings
pos_enc = similar(x, hidden_size, seq_len)
for i in 1:hidden_size, j in 1:seq_len
pos_enc[i,j] = sincos_position_embed(j, i, hidden_size)
end

# Reshape to handle pairs properly
x_reshaped = reshape(x, 2, :)
pos_reshaped = reshape(pos_enc, 2, :)

# Now reinterpret as pairs
x_pairs = reinterpret(reshape, NTuple{2,eltype(x)}, x_reshaped)
pos_pairs = reinterpret(reshape, NTuple{2,eltype(pos_enc)}, pos_reshaped)

# Apply rotary transformation
y_reshaped = similar(x_reshaped)
y_pairs = reinterpret(reshape, NTuple{2,eltype(y_reshaped)}, y_reshaped)

# pos_pairs covers a single (hidden_size, seq_len) slice, so cycle it when x carries a batch dimension
n_pos = length(pos_pairs)
for i in axes(x_pairs, 1)
    y_pairs[i] = _rotary(x_pairs[i], pos_pairs[mod1(i, n_pos)])
end

# Reshape back to original dimensions
return reshape(y_reshaped, size(x))
end

"""
Apply rotary position embeddings to input tensor x.
"""
function with_rotary_position_embedding(x::AbstractArray)
hidden_size = size(x, 1)
iseven(hidden_size) || throw(ArgumentError("Feature dimension ($(hidden_size)) must be even"))
return _apply_rotary(x, size(x, 2))
end
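
A minimal usage sketch (assumes a 2-D input of size features × sequence length; the feature
dimension must be even):

x = randn(Float32, 8, 10)
y = with_rotary_position_embedding(x)   # each consecutive feature pair rotated by a position-dependent angle
size(y) == size(x)                      # true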

# Gradient rules
function ChainRulesCore.rrule(::typeof(_rotary), x_pair, pos_pair)
x1, x2 = x_pair
sin_θ, cos_θ = pos_pair
y1, y2 = _rotary(x_pair, pos_pair)

function rotary_pullback(Ȳ)
∂y1, ∂y2 = Ȳ

∂x1 = ∂y1 * cos_θ + ∂y2 * sin_θ
∂x2 = -∂y1 * sin_θ + ∂y2 * cos_θ

return (NoTangent(), (∂x1, ∂x2), NoTangent())
end

return (y1, y2), rotary_pullback
end

function ChainRulesCore.rrule(::typeof(_apply_rotary), x, seq_len)
y = _apply_rotary(x, seq_len)

function apply_rotary_pullback(Ȳ)
hidden_size = size(x, 1)

# Recalculate position encodings for gradient
pos_enc = similar(x, hidden_size, seq_len)
for i in 1:hidden_size, j in 1:seq_len
pos_enc[i,j] = sincos_position_embed(j, i, hidden_size)
end

# Reshape for gradient computation
x_reshaped = reshape(x, 2, :)
pos_reshaped = reshape(pos_enc, 2, :)
Ȳ_reshaped = reshape(Ȳ, 2, :)

x_pairs = reinterpret(reshape, NTuple{2,eltype(x)}, x_reshaped)
pos_pairs = reinterpret(reshape, NTuple{2,eltype(pos_enc)}, pos_reshaped)
Ȳ_pairs = reinterpret(reshape, NTuple{2,eltype(Ȳ)}, Ȳ_reshaped)

∂x_reshaped = similar(x_reshaped)
∂x_pairs = reinterpret(reshape, NTuple{2,eltype(∂x_reshaped)}, ∂x_reshaped)

# as in _apply_rotary, cycle pos_pairs across any batch dimension of x
n_pos = length(pos_pairs)
for i in axes(x_pairs, 1)
    _, pb = rrule(_rotary, x_pairs[i], pos_pairs[mod1(i, n_pos)])
    ∂x_pairs[i] = pb(Ȳ_pairs[i])[2]
end

return (NoTangent(), reshape(∂x_reshaped, size(x)), NoTangent())
end

return y, apply_rotary_pullback
end

function ChainRulesCore.rrule(::typeof(with_rotary_position_embedding), x::AbstractArray)
y = with_rotary_position_embedding(x)

function rotary_pullback(Ȳ)
_, ∂x, _ = rrule(_apply_rotary, x, size(x,2))[2](Ȳ)
return (NoTangent(), ∂x)
end

return y, rotary_pullback
end
6 changes: 3 additions & 3 deletions src/utils.jl
@@ -205,7 +205,7 @@ ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
"""
truncated_normal([rng], size...; mean = 0, std = 1, lo = -2, hi = 2) -> Array
truncated_normal([rng]; kw...) -> Function

Return an `Array{Float32}` of the given `size` where each element is drawn from a truncated normal distribution.
The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* randn(100))`.

@@ -253,7 +253,7 @@ ChainRulesCore.@non_differentiable truncated_normal(::Any...)
lecun_normal([rng]; kw...) -> Function

Return an `Array{Float32}` of the given `size` containing random numbers drawn from a truncated normal
distribution centered on 0 with stddev `sqrt(1 / fan_in)`, where `fan_in` is the number of input units
in the weight tensor.

# Examples
@@ -414,7 +414,7 @@ Has the following behaviour
* 2D: An identity matrix (useful for an identity matrix multiplication)
* More than 2D: A dense block array of center tap spatial filters (useful for an identity convolution)

Some caveats:
* Not all layers will be identity mapping when used with this init. Exceptions
include recurrent layers and normalization layers.

32 changes: 32 additions & 0 deletions test/layers/rotary.jl
@@ -0,0 +1,32 @@
using Flux: with_rotary_position_embedding

@testset "Rotary Position Embedding Tests" begin
Random.seed!(123)
test_sizes = [(2,2), (4,6), (8,10)]

for (n, d) in test_sizes
x = randn(n, d)
test_gradients(
with_rotary_position_embedding,
x;
rtol=1e-4,
atol=1e-4,
test_gpu=false,
compare_finite_diff=true,
loss=(f, x) -> sum(f(x))
)
end

# Edge cases
test_gradients(
with_rotary_position_embedding,
zeros(4, 6);
loss=(f, x) -> sum(f(x))
)

test_gradients(
with_rotary_position_embedding,
ones(4, 6);
loss=(f, x) -> sum(f(x))
)
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -55,6 +55,7 @@ Random.seed!(0)
include("layers/upsample.jl")
include("layers/show.jl")
include("layers/macro.jl")
include("layers/rotary.jl")
end

@testset "outputsize" begin