Skip to content

Commit

Permalink
CUDA Zygote fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
murrellb committed Nov 26, 2024
1 parent 5836186 commit 8b1c35e
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,14 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=
xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch))
xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch))

#Lazy permute dims. Need to test CUDA.
xq = PermutedDimsArray(xq, (1,3,2,4))
xk = PermutedDimsArray(xk, (1,3,2,4))
xv = PermutedDimsArray(xv, (1,3,2,4))
#Lazy permute dims. Need to test CUDA. Note: test fails.
#xq = PermutedDimsArray(xq, (1,3,2,4))
#xk = PermutedDimsArray(xk, (1,3,2,4))
#xv = PermutedDimsArray(xv, (1,3,2,4))

xq = permutedims(xq, (1,3,2,4))
xk = permutedims(xk, (1,3,2,4))
xv = permutedims(xv, (1,3,2,4))

xq_rope = apply_rotary_emb(xq, freqs_cis)
xk_rope = apply_rotary_emb(xk, freqs_cis)
Expand Down

0 comments on commit 8b1c35e

Please sign in to comment.