diff --git a/src/layers.jl b/src/layers.jl
index 130d4b5..1ca53ae 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -84,10 +84,14 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=
     xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch))
     xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch))
 
-    #Lazy permute dims. Need to test CUDA.
-    xq = PermutedDimsArray(xq, (1,3,2,4))
-    xk = PermutedDimsArray(xk, (1,3,2,4))
-    xv = PermutedDimsArray(xv, (1,3,2,4))
+    #Lazy permute dims. Need to test CUDA. Note: test fails.
+    #xq = PermutedDimsArray(xq, (1,3,2,4))
+    #xk = PermutedDimsArray(xk, (1,3,2,4))
+    #xv = PermutedDimsArray(xv, (1,3,2,4))
+
+    xq = permutedims(xq, (1,3,2,4))
+    xk = permutedims(xk, (1,3,2,4))
+    xv = permutedims(xv, (1,3,2,4))
 
     xq_rope = apply_rotary_emb(xq, freqs_cis)
     xk_rope = apply_rotary_emb(xk, freqs_cis)
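
For reviewers without a GPU at hand, a minimal sketch of the distinction the patch turns on. The dimensions and the CUDA.allowscalar detail below are illustrative assumptions, not taken from the failing test:

    using CUDA

    x = CUDA.rand(Float32, 8, 2, 4, 1)   # (head_dim, n_heads, seqlen, batch)

    # PermutedDimsArray wraps `x` lazily with no copy; GPU kernels that
    # expect a plain CuArray can fall back to scalar indexing on such
    # wrappers and error out when CUDA.allowscalar(false) is in effect.
    xlazy = PermutedDimsArray(x, (1, 3, 2, 4))

    # permutedims materializes a new CuArray in the permuted layout,
    # staying on the fast GPU path at the cost of one extra copy.
    xeager = permutedims(x, (1, 3, 2, 4))

    @assert size(xlazy) == size(xeager) == (8, 4, 2, 1)

The eager copy trades memory traffic for kernel compatibility, which is why the lazy lines are kept commented out rather than deleted: they can be revisited once the CUDA test passes with the wrapper type.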