Skip to content

Commit

Permalink
Fixes, conditional caching
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Dec 4, 2024
1 parent 0189aa9 commit aecca48
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
1 change: 0 additions & 1 deletion src/Jjama3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using StatsBase
using NNlib
using LogitSamplers
using LowRankLayers
using ReactantCore

using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer

Expand Down
26 changes: 14 additions & 12 deletions src/cache.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
mutable struct KVCache{T,A<:AbstractArray{T,4}}
head_dim::Int
n_kv_heads::Int
seq_length::Int
batch_size::Int
cache_k::A
cache_v::A
end

Flux.@layer KVCache

head_dim(cache::KVCache) = size(cache.cache_k, 1)
seq_length(cache::KVCache) = size(cache.cache_k, 2)
n_kv_heads(cache::KVCache) = size(cache.cache_k, 3)
batch_size(cache::KVCache) = size(cache.cache_k, 4)

Check warning on line 11 in src/cache.jl

View check run for this annotation

Codecov / codecov/patch

src/cache.jl#L8-L11

Added lines #L8 - L11 were not covered by tests

function KVCache(T; head_dim, seq_length=0, n_kv_heads, batch_size=1)
cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
return KVCache(head_dim, n_kv_heads, seq_length, batch_size, cache_k, cache_v)
return KVCache(cache_k, cache_v)

Check warning on line 16 in src/cache.jl

View check run for this annotation

Codecov / codecov/patch

src/cache.jl#L13-L16

Added lines #L13 - L16 were not covered by tests
end

function config!(cache::KVCache; seq_length=cache.seq_length, batch_size=cache.batch_size)
cache.cache_k = similar(cache.cache_k, cache.head_dim, seq_length, cache.n_kv_heads, batch_size) .= 0
cache.cache_v = similar(cache.cache_v, cache.head_dim, seq_length, cache.n_kv_heads, batch_size) .= 0
function config!(cache::KVCache; seq_length=seq_length(cache), batch_size=batch_size(cache))
cache.cache_k = similar(cache.cache_k, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0

Check warning on line 21 in src/cache.jl

View check run for this annotation

Codecov / codecov/patch

src/cache.jl#L19-L21

Added lines #L19 - L21 were not covered by tests
end

clear!(cache::KVCache) = config!(cache, seq_length=0)

Check warning on line 24 in src/cache.jl

View check run for this annotation

Codecov / codecov/patch

src/cache.jl#L24

Added line #L24 was not covered by tests

function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray)
#if iszero(cache.seq_length)
# return xk, xv
#else
if iszero(seq_length(cache))
println("fuck")
return xk, xv

Check warning on line 29 in src/cache.jl

View check run for this annotation

Codecov / codecov/patch

src/cache.jl#L26-L29

Added lines #L26 - L29 were not covered by tests
else
seqlen = size(xk, 2)
cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk
cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv
return cache.cache_k[:, 1:start_pos+seqlen, :, :],

Check warning on line 34 in src/cache.jl

View check run for this annotation

Codecov / codecov/patch

src/cache.jl#L31-L34

Added lines #L31 - L34 were not covered by tests
cache.cache_v[:, 1:start_pos+seqlen, :, :]
#end
end
end

0 comments on commit aecca48

Please sign in to comment.