Skip to content

Commit

Permalink
test/runtests.jl: Add missing tests for positional masks.
Browse files Browse the repository at this point in the history
  • Loading branch information
mashu committed Dec 4, 2024
1 parent a6a2265 commit 6827017
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/PositionalEmbeddings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ julia> mask = create_causal_mask(3)[:,:,1]
0 0 1 # Third position can attend to first and second
```
"""
function causal_mask(seq_len::Int)
function create_causal_mask(seq_len::Int)
return reshape(triu(trues(seq_len, seq_len), 0), seq_len, seq_len, 1)
end

Expand Down
28 changes: 28 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,4 +233,32 @@ end
@test size(output) == (seq_len, head_size, batch_size)
@test output[1:5, 1:5, 1] expected rtol=1e-5
end
end

@testset "Mask Functions" begin
@testset "create_causal_mask" begin
# Test a simple 3x3 causal mask
result = create_causal_mask(3)
expected = reshape([
1 1 1
0 1 1
0 0 1
], 3, 3, 1)
@test result == expected
@test size(result) == (3, 3, 1)
end

@testset "create_padding_mask" begin
# Test with sequences of different lengths
lengths = [2, 3, 1]
max_len = 4
result = create_padding_mask(lengths, max_len)
expected = reshape(permutedims([
0 0 1 1
0 0 0 1
0 1 1 1]
,(2,1)),(1,4,3))
@test result == expected
@test size(result) == (1, 4, 3)
end
end

2 comments on commit 6827017

@mashu
Copy link
Owner Author

@mashu mashu commented on 6827017 Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Optimize for shape head_size x seq_len x (nbatch*batch_size), update docs and examples. Provide working RoPEMultiHeadAttention example in docs that does not depend on Flux.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120664

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.0 -m "<description of version>" 682701786116a59b21881609954467d95d866682
git push origin v0.6.0

Please sign in to comment.