Commit a6a2265: src/PositionalEmbeddings.jl: Sanitize docstrings.

mashu committed Dec 4, 2024
1 parent be40e36 commit a6a2265
Showing 1 changed file with 3 additions and 17 deletions.

20 changes: 3 additions & 17 deletions src/PositionalEmbeddings.jl
@@ -171,19 +171,12 @@ where the upper triangle including diagonal is masked (True) and the lower triangle…
- 3D boolean array of shape (seq_len, seq_len, 1) where True indicates positions to mask
# Examples
```jldoctest
```
julia> mask = create_causal_mask(3)[:,:,1]
3×3 Matrix{Bool}:
1 1 1 # First position can't attend anything
0 1 1 # Second position can attend to first only
0 0 1 # Third position can attend to first and second
# Used with attention_mask:
julia> attention_mask(mask)[:,:,1]
3×3 Matrix{Float64}:
-Inf -Inf -Inf
0.0 -Inf -Inf
0.0 0.0 -Inf
```
"""
function causal_mask(seq_len::Int)
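For reference, here is a minimal standalone sketch of the behavior this docstring describes. It is an editor's illustration, not the package's implementation: the names `causal_mask_sketch` and `attention_mask_sketch` are hypothetical, and the convention (true on and above the diagonal marks positions to hide, which the attention mask then maps to -Inf) is taken from the example output above.

```julia
# Editor's sketch; hypothetical helpers, not the PositionalEmbeddings.jl API.

# Boolean causal mask: true on and above the diagonal, as in the 3x3 example.
function causal_mask_sketch(seq_len::Int)
    mask = [col >= row for row in 1:seq_len, col in 1:seq_len]
    return reshape(mask, seq_len, seq_len, 1)      # (seq_len, seq_len, 1)
end

# Turn a Bool mask into additive scores: true -> -Inf, false -> 0.0.
attention_mask_sketch(mask) = ifelse.(mask, -Inf, 0.0)

causal_mask_sketch(3)[:, :, 1]    # Bool[1 1 1; 0 1 1; 0 0 1]
```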
@@ -204,24 +197,18 @@ Create padding masks for batched sequences of varying lengths. This ensures that…
- 3D boolean array of shape (batch_size, max_len, 1) where True indicates padded positions
# Examples
```jldoctest
```
# For 2 sequences of lengths 2 and 3, padded to length 4:
julia> mask = create_padding_mask([2, 3], 4)[:,:,1]
2×4 Matrix{Bool}:
0 0 1 1 # First sequence: length 2, positions 3-4 are padding
0 0 0 1 # Second sequence: length 3, position 4 is padding
# Combined with attention_mask:
julia> attention_mask(mask)[:,:,1]
2×4 Matrix{Float64}:
0.0 0.0 -Inf -Inf
0.0 0.0 0.0 -Inf
```
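The padding mask admits an equally small sketch, again as an editor's illustration under the same assumptions (`padding_mask_sketch` is a hypothetical name; true marks padded positions, matching the 2×4 example above):

```julia
# Editor's sketch: true wherever a position lies past that sequence's length.
function padding_mask_sketch(lengths::Vector{Int}, max_len::Int)
    mask = [col > lengths[row] for row in eachindex(lengths), col in 1:max_len]
    return reshape(mask, length(lengths), max_len, 1)   # (batch, max_len, 1)
end

padding_mask_sketch([2, 3], 4)[:, :, 1]   # Bool[0 0 1 1; 0 0 0 1]
```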
# Usage with Causal Mask
Padding and causal masks are often combined for batched autoregressive tasks:
```julia
```
seq_len = 5
batch_lengths = [3, 4]
@@ -231,7 +218,6 @@ padding = create_padding_mask(batch_lengths, seq_len) # Shape: (2, 5, 1)
# Combine the masks to prevent attending to future tokens or padding tokens
combined = causal .| padding
final_mask = attention_mask(combined)
# final_mask will prevent:
# 1. Attending to future tokens (from causal mask)
# 2. Attending to padding tokens (from padding mask)
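The combined-usage snippet above can be made runnable with the hypothetical sketches from earlier. One assumption to flag: as printed, the causal mask (5, 5, 1) and the padding mask (2, 5, 1) do not broadcast under `.|`, so this sketch first moves the batch axis to the third dimension; the package itself may arrange shapes differently.

```julia
seq_len = 5
batch_lengths = [3, 4]

causal  = causal_mask_sketch(seq_len)                  # (5, 5, 1)
padding = padding_mask_sketch(batch_lengths, seq_len)  # (2, 5, 1)

# Put batch last so the masks broadcast: (5,5,1) .| (1,5,2) -> (5,5,2)
combined   = causal .| permutedims(padding, (3, 2, 1))
final_mask = attention_mask_sketch(combined)           # -Inf where masked, else 0.0
size(final_mask)                                       # (5, 5, 2)
```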
