From 8b1c35eb9f618799bafa1e3501af74bd30a7e0ac Mon Sep 17 00:00:00 2001 From: murrellb Date: Tue, 26 Nov 2024 21:58:31 +0100 Subject: [PATCH] CUDA Zygote fix. --- src/layers.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/layers.jl b/src/layers.jl index 130d4b5..1ca53ae 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -84,10 +84,14 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - #Lazy permute dims. Need to test CUDA. - xq = PermutedDimsArray(xq, (1,3,2,4)) - xk = PermutedDimsArray(xk, (1,3,2,4)) - xv = PermutedDimsArray(xv, (1,3,2,4)) + #Lazy permute dims. Need to test CUDA. Note: test fails. + #xq = PermutedDimsArray(xq, (1,3,2,4)) + #xk = PermutedDimsArray(xk, (1,3,2,4)) + #xv = PermutedDimsArray(xv, (1,3,2,4)) + + xq = permutedims(xq, (1,3,2,4)) + xk = permutedims(xk, (1,3,2,4)) + xv = permutedims(xv, (1,3,2,4)) xq_rope = apply_rotary_emb(xq, freqs_cis) xk_rope = apply_rotary_emb(xk, freqs_cis)