test/layers/rotary.jl: Add gradient_test with_rotary_position_embedding(...)

mashu committed Nov 14, 2024
1 parent 5abe0e8 commit 9ca7d32
Showing 3 changed files with 36 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/layers/rotary.jl
@@ -1,19 +1,19 @@
"""
Rotary Position Embeddings (RoPE)
This is a port of the RoPE implementation from NeuralAttentionlib.jl, which is an implementation of
This is a port and simplified code of the RoPE implementation from NeuralAttentionlib.jl, which is an implementation of
the Rotary Position Embeddings (RoPE) described in the RoFormer paper.
Original sources:
- Paper: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
Authors: Jianlin Su, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen
URL: https://arxiv.org/abs/2104.09864
- Code: NeuralAttentionlib.jl
Author: chengchingwen
Repository: https://github.com/chengchingwen/NeuralAttentionlib.jl
RoPE encodes absolute positional information with a rotation matrix that naturally
RoPE encodes absolute positional information with a rotation matrix that naturally
incorporates explicit relative position dependency in self-attention formulation.
"""

32 changes: 32 additions & 0 deletions test/layers/rotary.jl
@@ -0,0 +1,32 @@
using Flux: with_rotary_position_embedding

@testset "Rotary Position Embedding Tests" begin
    Random.seed!(123)
    test_sizes = [(2,2), (4,6), (8,10)]

    for (n, d) in test_sizes
        x = randn(n, d)
        test_gradients(
            with_rotary_position_embedding,
            x;
            rtol=1e-4,
            atol=1e-4,
            test_gpu=false,
            compare_finite_diff=true,
            loss=(f, x) -> sum(f(x))
        )
    end

    # Edge cases
    test_gradients(
        with_rotary_position_embedding,
        zeros(4, 6);
        loss=(f, x) -> sum(f(x))
    )

    test_gradients(
        with_rotary_position_embedding,
        ones(4, 6);
        loss=(f, x) -> sum(f(x))
    )
end
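
As a usage note, test_gradients with compare_finite_diff=true checks the AD gradient of the loss against a numerical one. The sketch below hand-rolls that comparison under stated assumptions: finite_diff_grad is a hypothetical helper written for illustration, and Zygote is assumed as the AD backend (Flux's default); the actual test utility may differ in structure and tolerances.

using Flux: with_rotary_position_embedding
using Zygote  # assumed AD backend (Flux's default)

# Hypothetical helper: central-difference gradient of a scalar function.
function finite_diff_grad(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = randn(4, 6)                                   # even feature dim, as in the testset
loss(x) = sum(with_rotary_position_embedding(x))  # same loss the tests use
g_ad = Zygote.gradient(loss, x)[1]                # reverse-mode (AD) gradient
g_fd = finite_diff_grad(loss, x)                  # numerical reference
@assert isapprox(g_ad, g_fd; rtol = 1e-4, atol = 1e-4)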
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -55,6 +55,7 @@ Random.seed!(0)
include("layers/upsample.jl")
include("layers/show.jl")
include("layers/macro.jl")
include("layers/rotary.jl")
end

@testset "outputsize" begin
