From c5b77ff0ca5c8c7b6e65bdc23c7906f110af2d5a Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Tue, 3 Dec 2024 03:44:06 +0100 Subject: [PATCH 1/6] Reactant compatibility --- Project.toml | 6 +- README.md | 2 +- ext/MetalExt.jl | 6 +- src/Jjama3.jl | 71 +++++++++-------- src/layers.jl | 203 ++++++++++++++++++++++++++++-------------------- src/model.jl | 176 +++++------------------------------------ src/sampling.jl | 46 +++++------ 7 files changed, 205 insertions(+), 305 deletions(-) diff --git a/Project.toml b/Project.toml index b45025b..9d797c8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Jjama3" uuid = "1285d783-1a6d-4703-8f05-8ac83ef55592" authors = ["murrellb and contributors"] -version = "1.0.0-DEV" +version = "1.1.0-DEV" [deps] BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogitSamplers = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041" LowRankLayers = "b66182ab-a85c-43b0-99bd-d85cc47c5e50" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -20,7 +21,6 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" [sources] HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"} -LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"} LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"} [extensions] @@ -30,9 +30,11 @@ MetalExt = "Metal" BytePairEncoding = "0.5" Distributions = "0.25" Flux = "0.14" +LogitSamplers = "0.1" LowRankLayers = "1.0.0" Metal = "1" NNlib = "0.9" +ReactantCore = "0.1.2" SafeTensors = "1" StatsBase = "0.34" julia = "1.11" diff --git a/README.md b/README.md index 22e9fe3..1576d93 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ generate(model, prompt, - RoPE scaling (for exceeding the model's max training-time context length) is implemented, but likely incorrect with KV cache. Be careful if you're using with really long sequences. - Imported models are trainable (with Flux), including with low-rank (ie. LoRA) finetuning. - Sampling, training, etc compatible with CUDA, where everything is much faster. -- Metal acceleration for forward_inference, forward_loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU. +- Metal acceleration for forward inference, forward loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU. ## Samplers diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 5383601..85ce17d 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -1,7 +1,9 @@ module MetalExt -#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling. -#It seems that each single Metal call has some constant overhead that kills it. +# See https://github.com/FluxML/NNlib.jl/pull/614 + +# Note: Metal speeds things up a little for forward inference and forward_loss calls, but is VERY slow for sampling. +# It seems that each single Metal call has some constant overhead that kills it. 
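# A rough sketch of the intended use, assuming Flux's `gpu` movement picks up Metal once
# this extension is loaded (the call pattern mirrors the README; exact names are illustrative):
#
#   using Flux, Metal, Jjama3
#   model = gpu(model)                                   # weights become Metal arrays
#   logits = forward_inference(model, gpu(reshape(tokens, :, 1)), 0)
#
# Batched forward passes amortize the per-call overhead; token-by-token sampling does not,
# which is why sampling stays faster on the CPU.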
using Metal, NNlib diff --git a/src/Jjama3.jl b/src/Jjama3.jl index f0db5b6..06976ef 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,44 +1,51 @@ module Jjama3 -using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib -using LogitSamplers, LowRankLayers -import HuggingFaceTokenizers - +using Flux +using SafeTensors +using Distributions +using LinearAlgebra +using StatsBase +using NNlib +using LogitSamplers +using LowRankLayers +using ReactantCore + +using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained const tokenizer_from_file = HuggingFaceTokenizers.from_file -const Tokenizer = HuggingFaceTokenizers.Tokenizer - -const top_pk_sampler = LogitSamplers.top_pk_sampler -const argmax_sampler = LogitSamplers.argmax_sampler -const min_p_sampler = LogitSamplers.min_p_sampler -const top_nσ_sampler = LogitSamplers.top_nσ_sampler - - include("layers.jl") +export KVCache +export FeedForward +export RMSNorm +export TransformerBlock +export Transformer +export RoPE + include("model.jl") -include("utils.jl") +export forward_loss +export forward_inference + include("sampling.jl") +export top_pk_sampler +export argmax_sampler +export top_nσ_sampler +export min_p_sampler +export generate +export tokenizer_from_repo +export tokenizer_from_file +export Tokenizer -export load_llama321B_from_safetensors, - load_llama3_from_safetensors, - generate, - forward_loss, - forward_inference, - top_pk_sampler, - argmax_sampler, - top_nσ_sampler, - min_p_sampler, - tokenizer_from_repo, - tokenizer_from_file, - encode, - decode, - Tokenizer, - llama3_instruct_prompt, - llama3_assistant_prompt, - smollm2_instruct_prompt, - smollm2_assistant_prompt, - structured_choice +include("utils.jl") +export encode +export decode +export load_llama321B_from_safetensors +export load_llama3_from_safetensors +export llama3_instruct_prompt +export llama3_assistant_prompt +export smollm2_instruct_prompt +export smollm2_assistant_prompt +export structured_choice end diff --git a/src/layers.jl b/src/layers.jl index 73fd654..2c3495a 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -1,6 +1,8 @@ +const AnyDense = Union{Dense, LoRADense} + struct KVCache{T} - cache_k::AbstractArray{T, 4} # (head_dim, seq_len, n_kv_heads, batch) - cache_v::AbstractArray{T, 4} + cache_k::AbstractArray{T,4} # (head_dim, seq_len, n_kv_heads, batch) + cache_v::AbstractArray{T,4} end function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int; device = identity) @@ -9,10 +11,11 @@ function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim: KVCache(cache_k, cache_v) end -struct FeedForward - w1::Union{Dense, LoRADense} - w2::Union{Dense, LoRADense} - w3::Union{Dense, LoRADense} + +struct FeedForward{W<:AnyDense} + w1::W + w2::W + w3::W end function FeedForward(dim::Int, ff_hidden_dim::Int) @@ -23,38 +26,92 @@ function FeedForward(dim::Int, ff_hidden_dim::Int) ) end -function (ff::FeedForward)(x) - return ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) -end +(ff::FeedForward)(x) = ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) Flux.@layer :expand FeedForward -struct RMSNorm{T} - weight::AbstractVector{T} + +struct RMSNorm{T,W<:AbstractVector{T}} + weight::W eps::T end -function RMSNorm(dim::Int; eps::T=1f-5) where T - RMSNorm{T}(ones(T, dim), eps) -end +RMSNorm(dim::Int; eps::T=1f-5) where {T} = RMSNorm(ones(T, dim), eps) function (norm::RMSNorm)(x) - rms = sqrt.(sum(abs2.(x), dims=1) ./ size(x,1) .+ norm.eps) + rms = sqrt.(sum(abs2, x, dims=1) ./ 
size(x,1) .+ norm.eps) return x .* (norm.weight ./ rms) end Flux.@layer RMSNorm -mutable struct Attention - wq::Union{Dense, LoRADense} - wk::Union{Dense, LoRADense} - wv::Union{Dense, LoRADense} - wo::Union{Dense, LoRADense} + +struct RoPE{A<:AbstractArray} + cos::A + sin::A +end + +Base.getindex(rope::RoPE, i) = RoPE(selectdim(rope.cos, 2, i), selectdim(rope.sin, 2, i)) + +function apply_scaling!(freqs::AbstractVector; scale_factor=8) + #Hard-coded - I should move these to the main model struct and grab them from the config. + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 + ### + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + for (i, freq) in enumerate(freqs) + wavelen = 2π / freq + if wavelen > low_freq_wavelen + freqs[i] = freq / scale_factor + elseif wavelen > high_freq_wavelen + @assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor) + freqs[i] = (1 - smooth) * freq / scale_factor + smooth * freq + end + end + return freqs +end + +function RoPE( + dim::Int, end_pos::Int; + theta::T=10000f0, use_scaled=true, scale_factor=8, +) where T + freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim)) + use_scaled && apply_scaling!(freqs; scale_factor) + freqs_complex = cis.(T.(0:end_pos-1) * freqs') + cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len) + sin = permutedims(imag(freqs_complex), (2, 1)) + cos = reshape(cos, (dim÷2, end_pos, 1, 1)) + sin = reshape(sin, (dim÷2, end_pos, 1, 1)) + return RoPE(cos, sin) +end + +#Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 +#Use this one if you're using the Hugging Face weights. +function (rope::RoPE)(x) + head_dim = size(x, 1) + x1 = @view x[1:head_dim÷2, :, :, :] + x2 = @view x[head_dim÷2+1:end, :, :, :] + return vcat( + x1 .* rope.cos .- x2 .* rope.sin, + x2 .* rope.cos .+ x1 .* rope.sin + ) +end + + +mutable struct Attention{Q,K,V,O} + wq::Q + wk::K + wv::V + wo::O n_heads::Int n_kv_heads::Int head_dim::Int n_rep::Int - cache::Union{Nothing, KVCache} + #cache::Union{Nothing, KVCache} end function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) @@ -69,11 +126,11 @@ function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) n_kv_heads, head_dim, n_rep, - nothing + #nothing ) end -function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T +function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, rope::RoPE, mask=false) where T dim, seqlen, batch = size(x) xq = attn.wq(x) @@ -84,48 +141,25 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - #Lazy permute dims. Need to test CUDA. Note: test fails. 
- #xq = PermutedDimsArray(xq, (1,3,2,4)) - #xk = PermutedDimsArray(xk, (1,3,2,4)) - #xv = PermutedDimsArray(xv, (1,3,2,4)) - xq = permutedims(xq, (1,3,2,4)) xk = permutedims(xk, (1,3,2,4)) xv = permutedims(xv, (1,3,2,4)) - xq_rope = apply_rotary_emb(xq, freqs_cis) - xk_rope = apply_rotary_emb(xk, freqs_cis) - - if !isnothing(attn.cache) - xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) - end + xq_rope = rope(xq) + xk_rope = rope(xk) + #@trace if !isnothing(attn.cache) + # xk_rope, xv = update_kv_cache!(attn.cache, start_pos, xk_rope, xv) + #end xk_rope = repeat_kv(xk_rope, attn.n_rep) xv = repeat_kv(xv, attn.n_rep) xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - - #= - scores = batched_mul( - permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) - #batched_transpose(xq_for_attn), # (seqlen, head_dim, batch*heads) - xk_for_attn # (head_dim, seqlen, batch*heads) - ) ./ sqrt(T(attn.head_dim)) - if !isnothing(mask) - scores = scores .+ mask - end - sm_scores = softmax(scores; dims=2) - output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) - e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) - p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) - =# - - scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) ./ sqrt(T(attn.head_dim)) - if !isnothing(mask) - scores = scores .+ mask - end + + scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) / sqrt(T(attn.head_dim)) + scores .+= mask sm_scores = softmax(scores; dims=1) output = batched_mul(xv_for_attn, sm_scores) e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) @@ -136,13 +170,14 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= return proj end -Flux.@layer :expand Attention trainable=(wq, wv) +Flux.@layer :expand Attention trainable=(wq,wv) + -struct TransformerBlock - attention::Attention - feed_forward::FeedForward - attention_norm::RMSNorm - ffn_norm::RMSNorm +struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm} + attention::A + feed_forward::F + attention_norm::AN + ffn_norm::FN end function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim; @@ -155,42 +190,38 @@ function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hi ) end -function (block::TransformerBlock)(x, start_pos, freqs_cis, mask=nothing) - h = x + block.attention(block.attention_norm(x), start_pos, freqs_cis, mask) +function (block::TransformerBlock)(x, start_pos, rope, mask=nothing) + h = x + block.attention(block.attention_norm(x), start_pos, rope, mask) out = h + block.feed_forward(block.ffn_norm(h)) return out end -Flux.@layer TransformerBlock trainable=(attention, ) +Flux.@layer TransformerBlock trainable=(attention,) + -struct Transformer{T} - tok_embeddings::Flux.Embedding - layers::AbstractVector{TransformerBlock} - norm::RMSNorm{T} - output::Dense - freqs_cis::Tuple{AbstractArray{T, 4}, AbstractArray{T, 4}} +struct Transformer{E<:Flux.Embedding,B<:AbstractVector{<:TransformerBlock},N<:RMSNorm,O<:Dense,R<:RoPE} + tok_embeddings::E + layers::B + norm::N + output::O + rope::R end -function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, - n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; - norm_eps::T=1f-5, - qkv_bias=false, - 
rope_theta::T=500000f0, - use_scaled_rope=false, - scale_factor=8) where T - +function Transformer( + vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, + n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; + norm_eps::T=1f-5, + qkv_bias=false, + rope_theta::T=500000f0, + use_scaled_rope=false, + scale_factor=8, +) where T tok_embeddings = Flux.Embedding(vocab_size => dim) layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers] norm = RMSNorm(dim, eps=norm_eps) output = Dense(dim => vocab_size, bias=false) - freqs_cis = precompute_freqs_cis( - dim ÷ n_heads, - max_seq_len * 2; - theta=rope_theta, - use_scaled=use_scaled_rope, - scale_factor=scale_factor - ) - Transformer(tok_embeddings, layers, norm, output, freqs_cis) + rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor) + Transformer(tok_embeddings, layers, norm, output, rope) end -Flux.@layer :expand Transformer trainable=(layers, ) +Flux.@layer :expand Transformer trainable=(layers,) diff --git a/src/model.jl b/src/model.jl index 9fdb3c4..6186b0b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,86 +1,21 @@ #Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 -function apply_scaling(freqs::AbstractVector; scale_factor=8) - #Hard-coded - I should move these to the main model struct and grab them from the config. - low_freq_factor = 1 - high_freq_factor = 4 - old_context_len = 8192 - ### - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = similar(freqs) - for (i, freq) in enumerate(freqs) - wavelen = 2 * π / freq - if wavelen < high_freq_wavelen - new_freqs[i] = freq - elseif wavelen > low_freq_wavelen - new_freqs[i] = freq / scale_factor - else - @assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / - (high_freq_factor - low_freq_factor) - new_freqs[i] = (1 - smooth) * freq / scale_factor + smooth * freq - end - end - return new_freqs -end - -function precompute_freqs_cis(dim::Int, end_pos::Int; - theta::T=10000f0, use_scaled=true, scale_factor=8) where T - freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim)) - if use_scaled - freqs = apply_scaling(freqs; scale_factor=scale_factor) - end - freqs_complex = cis.(T.(0:end_pos-1) * freqs') - cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len) - sin = permutedims(imag(freqs_complex), (2, 1)) - cos = reshape(cos, (dim÷2, end_pos, 1, 1)) - sin = reshape(sin, (dim÷2, end_pos, 1, 1)) - return cos, sin -end - - -#Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 -#Use this one if you're using the Hugging Face weights. 
-function apply_rotary_emb(x, freqs_cis) - head_dim, seq_len, n_heads, batch = size(x) - x1 = @view x[1:head_dim÷2, :, :, :] - x2 = @view x[head_dim÷2+1:end, :, :, :] - cos, sin = freqs_cis - out = vcat( - x1 .* cos .- x2 .* sin, - x2 .* cos .+ x1 .* sin - ) - return out -end - -function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) +function update_kv_cache!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) seqlen = size(xk, 2) - cache.cache_k[:, (start_pos+1):(start_pos+seqlen), :, :] .= xk - cache.cache_v[:, (start_pos+1):(start_pos+seqlen), :, :] .= xv - return cache.cache_k[:, 1:(start_pos+seqlen), :, :], - cache.cache_v[:, 1:(start_pos+seqlen), :, :] + cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk + cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv + return cache.cache_k[:, 1:start_pos+seqlen, :, :], + cache.cache_v[:, 1:start_pos+seqlen, :, :] end -function repeat_kv(x::AbstractArray, n_rep::Int) - if n_rep == 1 - return x - end - return repeat(x, 1, n_rep, 1, 1) -end +repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1) -function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, start_pos::Int) where T - seqlen = size(tokens, 1) # tokens expected as (seq_len, batch) +function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int) h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) - - # Get relevant freqs_cis slice - cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) - freqs_cis = (cos[:,start_pos+1:start_pos+seqlen,:,:], sin[:,start_pos+1:start_pos+seqlen,:,:]) - - + rope = model.rope[start_pos+1:start_pos+size(tokens, 1)] mask = create_mask(h) for layer in model.layers - h = layer(h, start_pos, freqs_cis, mask) + h = layer(h, start_pos, rope, mask) end h = model.norm(h) output = model.output(h) @@ -88,28 +23,24 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st end function create_mask(h::AbstractArray) - Flux.Zygote.ignore() do - embeddim, seqlen, batch = size(h) - mask = similar(h, seqlen, seqlen) - T = eltype(h) - mask .= T(-Inf) - #mask = triu(mask, 1) - mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup - return mask - end + embeddim, seqlen, batch = size(h) + mask = similar(h, seqlen, seqlen) + T = eltype(h) + mask .= T(-Inf) + #mask = triu(mask, 1) + mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup + return mask end -function forward_loss(model::Transformer{T}, inputs::AbstractArray, +function forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; ignore_index::Int=-100, - mask = :auto) where T + mask = :auto) seqlen = size(inputs, 1) #(seq_len, batch) h = model.tok_embeddings(inputs) # (dim, seq_len, batch) cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) freqs_cis = (cos[:,1:seqlen,:,:], sin[:,1:seqlen,:,:]) # Forward through layers (start_pos = 0 disables KV caching) - if mask == :auto - mask = create_mask(h) - end + mask = mask == :auto ? create_mask(h) : mask for layer in model.layers h = layer(h, 0, freqs_cis, mask) end @@ -137,70 +68,5 @@ function forward_loss(model::Transformer{T}, inputs::AbstractArray, return loss end - -#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 -#= -#Use this one if you're using the original Meta weights. 
-#You'll need to change the type of the freqs_cis field in Transformer to match. -function precompute_freqs_cis(dim::Int, end_pos::Int; - theta::Float32=10000f0, - use_scaled::Bool=true, scale_factor::Int=8) - # Create frequencies for the first half of dimensions - freqs = 1f0 ./ (theta .^ (Float32.(0:2:dim-1)[1:dim÷2] ./ dim)) - # Create position indices - note, using 0 indexing here because python consistency. Not sure if it makes a difference. - t = Float32.(0:end_pos-1) - if use_scaled - freqs = apply_scaling(freqs; scale_factor=scale_factor) - end - # Compute outer product - freqs = t * freqs' - # Convert to complex exponentials - # Note: Julia's cis(x) = exp(ix) = cos(x) + i*sin(x) - freqs_complex = cis.(freqs) - # Stack real and imaginary parts - # Note: Julia's reshape is similar to PyTorch's stack - freqs_cis_real = reshape( - reinterpret(Float32, reshape(freqs_complex, :)), - (2, size(freqs)...) - ) - # Permute to match PyTorch's dimension ordering - return permutedims(freqs_cis_real, (2,3,1)) -end - -function apply_rotary_emb(x, freqs_cis) - # x is (head_dim, seq_len, n_heads, batch) in Julia - # freqs_cis is (seq_len, head_dim/2, 2) - - #@show size(freqs_cis) - - # Reshape x to separate real/imaginary pairs - head_dim, seq_len, n_heads, batch = size(x) - x_reshaped = reshape(x, (2, head_dim÷2, seq_len, n_heads, batch)) - - # Reshape freqs_cis to broadcast correctly - # Note: reshape to (2, head_dim/2, seq_len, 1, 1) for broadcasting - freqs_cis = permutedims(freqs_cis, (3, 2, 1)) # now (2, head_dim/2, seq_len) - freqs_cis = reshape(freqs_cis, (2, size(freqs_cis, 2), size(freqs_cis, 3), 1, 1)) - - # Apply rotation using complex multiplication formula: - # (a + bi)(c + di) = (ac-bd) + (ad+bc)i - x_real = x_reshaped[1:1, :, :, :, :] - x_imag = x_reshaped[2:2, :, :, :, :] - f_real = freqs_cis[1:1, :, :, :, :] - f_imag = freqs_cis[2:2, :, :, :, :] - - #@show size(f_real) - #@show size(f_imag) - - #This is for checking the freqs_cis. - #Note: the cos, sin values are repeated in python - #g(f_real, f_imag) #passes - - out_real = x_real .* f_real .- x_imag .* f_imag - out_imag = x_imag .* f_real .+ x_real .* f_imag - - # Combine and reshape back - out = vcat(out_real, out_imag) - return reshape(out, (head_dim, seq_len, n_heads, batch)) -end -=# +# compat +forward_inference(model, args...) = model(args...) diff --git a/src/sampling.jl b/src/sampling.jl index ee51d3d..2e3969d 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -3,44 +3,40 @@ # Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time. """ generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) - + Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now. Runs on CPU by default. If the model is on the GPU (assuming Flux.jl, eg. `model = gpu(model)`), then pass `device = gpu` to `generate` to run on the GPU. 
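For example, a GPU run might look like this (a sketch, assuming a CUDA device and Flux's `gpu`):

```julia
model = gpu(model)
generate(model, initial_tokens; max_new_tokens=100, device=gpu)
```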
- tkn = llama3_tokenizer() - generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) +```julia +tkn = llama3_tokenizer() +generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) +``` """ function generate(model::Transformer{T}, - initial_tokens::AbstractArray{IntT}; + initial_tokens::AbstractArray{<:Integer}; max_new_tokens=100, sampler::Function=argmax_sampler, tokenizer_for_printing = nothing, end_token = 128010, - device = identity) where {T, IntT} - - # Initialize sequence with a new copy of the tokens + device = identity) where T + current_len = length(initial_tokens) - tokens = Vector{IntT}(undef, current_len + max_new_tokens) - tokens[1:current_len] = initial_tokens - # Set up KV caches for all attention layers + tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens)) + for layer in model.layers layer.attention.cache = KVCache( - T, # eltype - 1, # batch_size + T, 1, # eltype, batch_size current_len + max_new_tokens, # max possible sequence length layer.attention.n_kv_heads, layer.attention.head_dim, device = device ) end - # Process the initial sequence - if current_len > 0 - input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) - logits = forward_inference(model, input_tokens, 0) - start_pos = current_len - else - start_pos = 0 - end + + input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) + logits = model(input_tokens, 0) + start_pos = current_len + # Generate new tokens one at a time for _ in 1:max_new_tokens # If sequence is empty or we want to process just the last token @@ -50,17 +46,13 @@ function generate(model::Transformer{T}, input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token end # Get logits for next token - logits = forward_inference(model, input_tokens, start_pos) + logits = model(input_tokens, start_pos) # Sample next token (logits are size vocab × 1 × 1) next_token = sampler(logits[:, end, 1]) current_len += 1 tokens[current_len] = next_token - if !isnothing(tokenizer_for_printing) - print(decode(tokenizer_for_printing, [next_token])) - end - if next_token == end_token - break - end + !isnothing(tokenizer_for_printing) && print(decode(tokenizer_for_printing, [next_token])) + next_token == end_token && break start_pos += 1 end # Clear KV caches From bc01519a92f60f3ee9cce9a956cf926503aba031 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Tue, 3 Dec 2024 04:02:16 +0100 Subject: [PATCH 2/6] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9d797c8..ca9724e 100644 --- a/Project.toml +++ b/Project.toml @@ -31,7 +31,7 @@ BytePairEncoding = "0.5" Distributions = "0.25" Flux = "0.14" LogitSamplers = "0.1" -LowRankLayers = "1.0.0" +LowRankLayers = "0.1" Metal = "1" NNlib = "0.9" ReactantCore = "0.1.2" From dc0bc2817d8d3de7df342665f07ba0ff9fd460b2 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 4 Dec 2024 03:43:19 +0100 Subject: [PATCH 3/6] Refactor --- Project.toml | 4 +- src/layers.jl | 113 +++++++++++++++++++++++++++++------------------- src/model.jl | 25 +++-------- src/sampling.jl | 30 ++++++------- src/utils.jl | 22 ++++++---- 5 files changed, 106 insertions(+), 88 deletions(-) diff --git a/Project.toml b/Project.toml index ca9724e..4a02757 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ authors = ["murrellb and 
contributors"] version = "1.1.0-DEV" [deps] -BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18" @@ -27,7 +27,7 @@ LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLay MetalExt = "Metal" [compat] -BytePairEncoding = "0.5" +Accessors = "0.1.38" Distributions = "0.25" Flux = "0.14" LogitSamplers = "0.1" diff --git a/src/layers.jl b/src/layers.jl index 2c3495a..b7ca584 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -1,16 +1,5 @@ const AnyDense = Union{Dense, LoRADense} -struct KVCache{T} - cache_k::AbstractArray{T,4} # (head_dim, seq_len, n_kv_heads, batch) - cache_v::AbstractArray{T,4} -end - -function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int; device = identity) - cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - KVCache(cache_k, cache_v) -end - struct FeedForward{W<:AnyDense} w1::W @@ -28,7 +17,7 @@ end (ff::FeedForward)(x) = ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) -Flux.@layer :expand FeedForward +Flux.@layer FeedForward struct RMSNorm{T,W<:AbstractVector{T}} @@ -36,7 +25,7 @@ struct RMSNorm{T,W<:AbstractVector{T}} eps::T end -RMSNorm(dim::Int; eps::T=1f-5) where {T} = RMSNorm(ones(T, dim), eps) +RMSNorm(dim::Int; eps::T=1f-5) where T = RMSNorm(ones(T, dim), eps) function (norm::RMSNorm)(x) rms = sqrt.(sum(abs2, x, dims=1) ./ size(x,1) .+ norm.eps) @@ -51,7 +40,9 @@ struct RoPE{A<:AbstractArray} sin::A end -Base.getindex(rope::RoPE, i) = RoPE(selectdim(rope.cos, 2, i), selectdim(rope.sin, 2, i)) +Flux.@layer RoPE + +Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:]) function apply_scaling!(freqs::AbstractVector; scale_factor=8) #Hard-coded - I should move these to the main model struct and grab them from the config. @@ -93,8 +84,8 @@ end #Use this one if you're using the Hugging Face weights. 
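#The rotation below uses the "rotate-half" layout (first half of each head paired with the
#second half), matching the Hugging Face weight layout discussed in the link above, rather
#than the even/odd interleaving of the original Meta code. Per position and head it computes:
#  [x1; x2] -> [x1 .* cos .- x2 .* sin;
#               x2 .* cos .+ x1 .* sin]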
function (rope::RoPE)(x) head_dim = size(x, 1) - x1 = @view x[1:head_dim÷2, :, :, :] - x2 = @view x[head_dim÷2+1:end, :, :, :] + x1 = x[1:head_dim÷2, :, :, :] + x2 = x[head_dim÷2+1:end, :, :, :] return vcat( x1 .* rope.cos .- x2 .* rope.sin, x2 .* rope.cos .+ x1 .* rope.sin @@ -102,36 +93,68 @@ function (rope::RoPE)(x) end -mutable struct Attention{Q,K,V,O} +mutable struct KVCache{T,A<:AbstractArray{T,4}} + cache_k::A + cache_v::A +end + +Flux.@layer KVCache + +function KVCache(T; batch_size=0, seq_length=0, n_kv_heads=0, head_dim=0) + cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + return KVCache(cache_k, cache_v) +end + +function reset_kv_cache!(cache::KVCache; batch_size, seq_length, n_kv_heads, head_dim) + cache.cache_k = similar(cache.cache_k, head_dim, seq_length, n_kv_heads, batch_size) .= 0 + cache.cache_v = similar(cache.cache_v, head_dim, seq_length, n_kv_heads, batch_size) .= 0 +end + +clear_kv_cache!(cache::KVCache) = reset_kv_cache!(cache, batch_size=0, seq_length=0, n_kv_heads=0, head_dim=0) + +function update_kv_cache!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) + seqlen = size(xk, 2) + cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk + cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv + return cache.cache_k[:, 1:start_pos+seqlen, :, :], + cache.cache_v[:, 1:start_pos+seqlen, :, :] +end + + +struct Attention{Q,K,V,O,C<:KVCache} wq::Q wk::K wv::V wo::O + dim::Int n_heads::Int n_kv_heads::Int head_dim::Int - n_rep::Int - #cache::Union{Nothing, KVCache} + cache::C end +Flux.@layer Attention trainable=(wq,wv) + function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) head_dim = dim ÷ n_heads - n_rep = n_heads ÷ n_kv_heads Attention( Dense(dim => n_heads * head_dim, bias=qkv_bias), Dense(dim => n_kv_heads * head_dim, bias=qkv_bias), Dense(dim => n_kv_heads * head_dim, bias=qkv_bias), Dense(n_heads * head_dim => dim, bias=false), + dim, n_heads, n_kv_heads, head_dim, - n_rep, - #nothing + KVCache(T) # starts off empty ) end -function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, rope::RoPE, mask=false) where T - dim, seqlen, batch = size(x) +repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? 
x : repeat(x, 1, n_rep, 1, 1) + +function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, rope=nothing, mask=false) where T + _, seqlen, batch = size(x) xq = attn.wq(x) xk = attn.wk(x) @@ -145,33 +168,33 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, rope::RoPE, mask xk = permutedims(xk, (1,3,2,4)) xv = permutedims(xv, (1,3,2,4)) - xq_rope = rope(xq) - xk_rope = rope(xk) + if rope isa RoPE + xq, xk = rope(xq), rope(xk) + end + + if attn.cache isa KVCache + xk, xv = update_kv_cache!(attn.cache, start_pos, xk, xv) + end - #@trace if !isnothing(attn.cache) - # xk_rope, xv = update_kv_cache!(attn.cache, start_pos, xk_rope, xv) - #end - xk_rope = repeat_kv(xk_rope, attn.n_rep) - xv = repeat_kv(xv, attn.n_rep) - - xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) - xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) - xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) + xk = repeat_kv(xk, attn.n_heads ÷ attn.n_kv_heads) + xv = repeat_kv(xv, attn.n_heads ÷ attn.n_kv_heads) + + xq_for_attn = reshape(xq, attn.head_dim, seqlen, :) + xk_for_attn = reshape(xk, attn.head_dim, seqlen, :) + xv_for_attn = reshape(xv, attn.head_dim, seqlen, :) scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) / sqrt(T(attn.head_dim)) scores .+= mask sm_scores = softmax(scores; dims=1) + output = batched_mul(xv_for_attn, sm_scores) e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) p_output = permutedims(e_output, (1,3,2,4)) - - r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) + r_output = reshape(p_output, (attn.n_heads * attn.head_dim, seqlen, batch)) proj = attn.wo(r_output) return proj end -Flux.@layer :expand Attention trainable=(wq,wv) - struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm} attention::A @@ -180,8 +203,10 @@ struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm} ffn_norm::FN end -function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim; - norm_eps=1f-5, qkv_bias=false) +function TransformerBlock( + dim::Int, n_heads::Int, n_kv_heads::Int = n_heads, ff_hidden_dim = 4 * dim; + norm_eps=1f-5, qkv_bias=false, +) TransformerBlock( Attention(dim, n_heads, n_kv_heads; qkv_bias), FeedForward(dim, ff_hidden_dim), @@ -199,7 +224,7 @@ end Flux.@layer TransformerBlock trainable=(attention,) -struct Transformer{E<:Flux.Embedding,B<:AbstractVector{<:TransformerBlock},N<:RMSNorm,O<:Dense,R<:RoPE} +struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE} tok_embeddings::E layers::B norm::N @@ -217,11 +242,11 @@ function Transformer( scale_factor=8, ) where T tok_embeddings = Flux.Embedding(vocab_size => dim) - layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers] + layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers) norm = RMSNorm(dim, eps=norm_eps) output = Dense(dim => vocab_size, bias=false) rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor) Transformer(tok_embeddings, layers, norm, output, rope) end -Flux.@layer :expand Transformer trainable=(layers,) +Flux.@layer Transformer trainable=(layers,) diff --git a/src/model.jl b/src/model.jl index 6186b0b..bd76b99 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,15 +1,14 @@ #Note 
about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 -function update_kv_cache!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) - seqlen = size(xk, 2) - cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk - cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv - return cache.cache_k[:, 1:start_pos+seqlen, :, :], - cache.cache_v[:, 1:start_pos+seqlen, :, :] +function create_mask(h::AbstractArray{T}) where T<:AbstractFloat + dim, seqlen, batch = size(h) + mask = similar(h, seqlen, seqlen) + mask .= T(-Inf) + #mask = triu(mask, 1) + mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup + return mask end -repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1) - function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int) h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) rope = model.rope[start_pos+1:start_pos+size(tokens, 1)] @@ -22,16 +21,6 @@ function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int) return output end -function create_mask(h::AbstractArray) - embeddim, seqlen, batch = size(h) - mask = similar(h, seqlen, seqlen) - T = eltype(h) - mask .= T(-Inf) - #mask = triu(mask, 1) - mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup - return mask -end - function forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; ignore_index::Int=-100, mask = :auto) diff --git a/src/sampling.jl b/src/sampling.jl index 2e3969d..e475a7d 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -12,24 +12,24 @@ tkn = llama3_tokenizer() generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) ``` """ -function generate(model::Transformer{T}, - initial_tokens::AbstractArray{<:Integer}; - max_new_tokens=100, - sampler::Function=argmax_sampler, - tokenizer_for_printing = nothing, - end_token = 128010, - device = identity) where T - +function generate( + model::Transformer{T}, + initial_tokens::AbstractArray{<:Integer}; + max_new_tokens=100, + sampler::Function=argmax_sampler, + tokenizer_for_printing = nothing, + end_token = 128010, + device = identity +) where T current_len = length(initial_tokens) tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens)) for layer in model.layers - layer.attention.cache = KVCache( - T, 1, # eltype, batch_size - current_len + max_new_tokens, # max possible sequence length - layer.attention.n_kv_heads, - layer.attention.head_dim, - device = device + reset_kv_cache!(layer.attention.cache, + batch_size = 1, + seq_length = current_len + max_new_tokens, + n_kv_heads = layer.attention.n_kv_heads, + head_dim = layer.attention.head_dim ) end @@ -57,7 +57,7 @@ function generate(model::Transformer{T}, end # Clear KV caches for layer in model.layers - layer.attention.cache = nothing + clear_kv_cache!(layer.attention.cache) end return tokens[1:current_len] end diff --git a/src/utils.jl b/src/utils.jl index c8758e3..65e7b5c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,5 @@ +using Accessors + encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1 decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...) 
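# The .+ 1 / .- 1 shifts above translate between the tokenizer's 0-based ids and the
# 1-based indices expected by Flux.Embedding. A rough end-to-end sketch (the repo name
# below is an illustrative assumption, not taken from this package):
#
#   tkn = tokenizer_from_repo("meta-llama/Llama-3.2-1B-Instruct")
#   prompt = encode(tkn, "Hello")
#   toks = generate(model, prompt; max_new_tokens=20, tokenizer_for_printing=tkn)
#   decode(tkn, toks)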
@@ -41,7 +43,10 @@ so if you're loading weights from a different source, you might get very poor mo model_weight_paths = ["Llama3_2_1B_instruct/model.safetensors"] #Can be an array of paths if the model is split across multiple files model = load_llama3_from_safetensors(model_weight_paths, config) """ -function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32, add_lora_to = Symbol[], lora_dim = 0) +function load_llama3_from_safetensors( + paths::Vector{String}, config; + T = Float32, add_lora_to = Symbol[], lora_dim = 0, +) config = Dict(config) #Just in case the user passed eg. a JSON3.Object #@assert config[:rope_scaling][:rope_type] == "llama3" #@assert config[:rope_scaling][:low_freq_factor] == 1 @@ -144,40 +149,39 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 end if !isempty(add_lora_to) - #Then load in the current layers: if :Q in add_lora_to for layer in model.layers - layer.attention.wq = LoRADense(layer.attention.wq, lora_dim) + @reset layer.attention.wq = LoRADense(layer.attention.wq, lora_dim) end end if :K in add_lora_to for layer in model.layers - layer.attention.wk = LoRADense(layer.attention.wk, lora_dim) + @reset layer.attention.wk = LoRADense(layer.attention.wk, lora_dim) end end if :V in add_lora_to for layer in model.layers - layer.attention.wv = LoRADense(layer.attention.wv, lora_dim) + @reset layer.attention.wv = LoRADense(layer.attention.wv, lora_dim) end end if :O in add_lora_to for layer in model.layers - layer.attention.wo = LoRADense(layer.attention.wo, lora_dim) + @reset layer.attention.wo = LoRADense(layer.attention.wo, lora_dim) end end if :w1 in add_lora_to for layer in model.layers - layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim) + @reset layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim) end end if :w2 in add_lora_to for layer in model.layers - layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim) + @reset layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim) end end if :w3 in add_lora_to for layer in model.layers - layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim) + @reset layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim) end end end From 8bffa9e7d92a7f8fd7295ef4a6785ed3a379f5f3 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 4 Dec 2024 22:33:39 +0100 Subject: [PATCH 4/6] Refactor and fixes --- src/Jjama3.jl | 7 +++-- src/cache.jl | 35 +++++++++++++++++++++++++ src/layers.jl | 69 ++++++++++++++----------------------------------- src/model.jl | 10 +++---- src/sampling.jl | 9 ++----- 5 files changed, 66 insertions(+), 64 deletions(-) create mode 100644 src/cache.jl diff --git a/src/Jjama3.jl b/src/Jjama3.jl index 06976ef..bd0862d 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -15,13 +15,16 @@ using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained const tokenizer_from_file = HuggingFaceTokenizers.from_file -include("layers.jl") +include("cache.jl") export KVCache + +include("layers.jl") export FeedForward export RMSNorm +export RoPE +export Attention export TransformerBlock export Transformer -export RoPE include("model.jl") export forward_loss diff --git a/src/cache.jl b/src/cache.jl new file mode 100644 index 0000000..8317e59 --- /dev/null +++ b/src/cache.jl @@ -0,0 +1,35 @@ +mutable struct KVCache{T,A<:AbstractArray{T,4}} + head_dim::Int + n_kv_heads::Int + seq_length::Int + batch_size::Int + cache_k::A + cache_v::A +end 
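+# cache_k and cache_v are shaped (head_dim, seq_length, n_kv_heads, batch_size);
+# the integer fields record the configured sizes so that config! below can re-allocate them.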
+ +Flux.@layer KVCache + +function KVCache(T; head_dim, seq_length=0, n_kv_heads, batch_size=1) + cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + return KVCache(head_dim, n_kv_heads, seq_length, batch_size, cache_k, cache_v) +end + +function config!(cache::KVCache; seq_length=cache.seq_length, batch_size=cache.batch_size) + cache.cache_k = similar(cache.cache_k, cache.head_dim, seq_length, cache.n_kv_heads, batch_size) .= 0 + cache.cache_v = similar(cache.cache_v, cache.head_dim, seq_length, cache.n_kv_heads, batch_size) .= 0 +end + +clear!(cache::KVCache) = config!(cache, seq_length=0) + +function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) + #if iszero(cache.seq_length) + # return xk, xv + #else + seqlen = size(xk, 2) + cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk + cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv + return cache.cache_k[:, 1:start_pos+seqlen, :, :], + cache.cache_v[:, 1:start_pos+seqlen, :, :] + #end +end diff --git a/src/layers.jl b/src/layers.jl index b7ca584..6b14443 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -7,6 +7,8 @@ struct FeedForward{W<:AnyDense} w3::W end +Flux.@layer FeedForward + function FeedForward(dim::Int, ff_hidden_dim::Int) FeedForward( Dense(dim => ff_hidden_dim, bias=false), @@ -17,23 +19,21 @@ end (ff::FeedForward)(x) = ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) -Flux.@layer FeedForward - -struct RMSNorm{T,W<:AbstractVector{T}} +struct RMSNorm{T<:AbstractFloat,W<:AbstractVector{T}} weight::W eps::T end +Flux.@layer RMSNorm + RMSNorm(dim::Int; eps::T=1f-5) where T = RMSNorm(ones(T, dim), eps) function (norm::RMSNorm)(x) - rms = sqrt.(sum(abs2, x, dims=1) ./ size(x,1) .+ norm.eps) + rms = sqrt.(sum(abs2, x, dims=1) / size(x, 1) .+ norm.eps) return x .* (norm.weight ./ rms) end -Flux.@layer RMSNorm - struct RoPE{A<:AbstractArray} cos::A @@ -45,7 +45,7 @@ Flux.@layer RoPE Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:]) function apply_scaling!(freqs::AbstractVector; scale_factor=8) - #Hard-coded - I should move these to the main model struct and grab them from the config. + #Hard-coded - should move these to the main model struct and grab them from the config. low_freq_factor = 1 high_freq_factor = 4 old_context_len = 8192 @@ -80,8 +80,9 @@ function RoPE( return RoPE(cos, sin) end -#Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 -#Use this one if you're using the Hugging Face weights. +# Note about Huggingface weights and rotary embeddings: +# https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 +# Use this one if you're using the Hugging Face weights. 
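# A quick shape sanity check (illustrative numbers: head_dim=8, seqlen=5, 2 heads, batch 1):
#
#   rope = RoPE(8, 16)
#   q = randn(Float32, 8, 5, 2, 1)     # (head_dim, seqlen, n_heads, batch)
#   size(rope[1:5](q)) == size(q)      # rotating positions 1:5 preserves the shape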
function (rope::RoPE)(x) head_dim = size(x, 1) x1 = x[1:head_dim÷2, :, :, :] @@ -93,36 +94,7 @@ function (rope::RoPE)(x) end -mutable struct KVCache{T,A<:AbstractArray{T,4}} - cache_k::A - cache_v::A -end - -Flux.@layer KVCache - -function KVCache(T; batch_size=0, seq_length=0, n_kv_heads=0, head_dim=0) - cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) - cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) - return KVCache(cache_k, cache_v) -end - -function reset_kv_cache!(cache::KVCache; batch_size, seq_length, n_kv_heads, head_dim) - cache.cache_k = similar(cache.cache_k, head_dim, seq_length, n_kv_heads, batch_size) .= 0 - cache.cache_v = similar(cache.cache_v, head_dim, seq_length, n_kv_heads, batch_size) .= 0 -end - -clear_kv_cache!(cache::KVCache) = reset_kv_cache!(cache, batch_size=0, seq_length=0, n_kv_heads=0, head_dim=0) - -function update_kv_cache!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) - seqlen = size(xk, 2) - cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk - cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv - return cache.cache_k[:, 1:start_pos+seqlen, :, :], - cache.cache_v[:, 1:start_pos+seqlen, :, :] -end - - -struct Attention{Q,K,V,O,C<:KVCache} +struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache} wq::Q wk::K wv::V @@ -147,13 +119,13 @@ function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) n_heads, n_kv_heads, head_dim, - KVCache(T) # starts off empty + KVCache(Float32; n_kv_heads, head_dim), ) end repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1) -function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, rope=nothing, mask=false) where T +function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T _, seqlen, batch = size(x) xq = attn.wq(x) @@ -172,16 +144,15 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, rope=nothing, ma xq, xk = rope(xq), rope(xk) end - if attn.cache isa KVCache - xk, xv = update_kv_cache!(attn.cache, start_pos, xk, xv) - end + # Update if cache is configured with seq_length > 0 + xk, xv = update!(attn.cache, start_pos, xk, xv) xk = repeat_kv(xk, attn.n_heads ÷ attn.n_kv_heads) xv = repeat_kv(xv, attn.n_heads ÷ attn.n_kv_heads) - xq_for_attn = reshape(xq, attn.head_dim, seqlen, :) - xk_for_attn = reshape(xk, attn.head_dim, seqlen, :) - xv_for_attn = reshape(xv, attn.head_dim, seqlen, :) + xq_for_attn = reshape(xq, attn.head_dim, :, attn.n_heads * batch) + xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch) + xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) / sqrt(T(attn.head_dim)) scores .+= mask @@ -232,6 +203,8 @@ struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNo rope::R end +Flux.@layer Transformer trainable=(layers,) + function Transformer( vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; @@ -248,5 +221,3 @@ function Transformer( rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor) Transformer(tok_embeddings, layers, norm, output, rope) end - -Flux.@layer Transformer trainable=(layers,) diff --git a/src/model.jl b/src/model.jl index bd76b99..179621d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -9,7 +9,7 @@ function create_mask(h::AbstractArray{T}) where 
T<:AbstractFloat return mask end -function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int) +function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int=0) h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) rope = model.rope[start_pos+1:start_pos+size(tokens, 1)] mask = create_mask(h) @@ -26,12 +26,10 @@ function forward_loss(model::Transformer, inputs::AbstractArray, mask = :auto) seqlen = size(inputs, 1) #(seq_len, batch) h = model.tok_embeddings(inputs) # (dim, seq_len, batch) - cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) - freqs_cis = (cos[:,1:seqlen,:,:], sin[:,1:seqlen,:,:]) - # Forward through layers (start_pos = 0 disables KV caching) - mask = mask == :auto ? create_mask(h) : mask + rope = model.rope[1:seqlen] + mask = create_mask(h) for layer in model.layers - h = layer(h, 0, freqs_cis, mask) + h = layer(h, 0, rope, mask) end h = model.norm(h) logits = model.output(h) diff --git a/src/sampling.jl b/src/sampling.jl index e475a7d..50abebe 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -25,12 +25,7 @@ function generate( tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens)) for layer in model.layers - reset_kv_cache!(layer.attention.cache, - batch_size = 1, - seq_length = current_len + max_new_tokens, - n_kv_heads = layer.attention.n_kv_heads, - head_dim = layer.attention.head_dim - ) + config!(layer.attention.cache, seq_length = current_len + max_new_tokens) end input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) @@ -57,7 +52,7 @@ function generate( end # Clear KV caches for layer in model.layers - clear_kv_cache!(layer.attention.cache) + clear!(layer.attention.cache) end return tokens[1:current_len] end From 0189aa98dab41a30de760d0dad718d1a2f6b44f0 Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 4 Dec 2024 22:36:11 +0100 Subject: [PATCH 5/6] rm ReactantCore --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4a02757..9ef394d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogitSamplers = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041" LowRankLayers = "b66182ab-a85c-43b0-99bd-d85cc47c5e50" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -34,7 +33,6 @@ LogitSamplers = "0.1" LowRankLayers = "0.1" Metal = "1" NNlib = "0.9" -ReactantCore = "0.1.2" SafeTensors = "1" StatsBase = "0.34" julia = "1.11" From aecca480dcb49c35bd67b8a8c24faa9b7033aa9f Mon Sep 17 00:00:00 2001 From: Anton Oresten Date: Wed, 4 Dec 2024 22:43:28 +0100 Subject: [PATCH 6/6] Fixes, conditional caching --- src/Jjama3.jl | 1 - src/cache.jl | 26 ++++++++++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/Jjama3.jl b/src/Jjama3.jl index bd0862d..d985039 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -8,7 +8,6 @@ using StatsBase using NNlib using LogitSamplers using LowRankLayers -using ReactantCore using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer diff --git a/src/cache.jl b/src/cache.jl index 8317e59..eda900e 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -1,35 +1,37 @@ mutable struct KVCache{T,A<:AbstractArray{T,4}} - head_dim::Int - n_kv_heads::Int - seq_length::Int - batch_size::Int cache_k::A cache_v::A end Flux.@layer KVCache +head_dim(cache::KVCache) = 
size(cache.cache_k, 1)
+seq_length(cache::KVCache) = size(cache.cache_k, 2)
+n_kv_heads(cache::KVCache) = size(cache.cache_k, 3)
+batch_size(cache::KVCache) = size(cache.cache_k, 4)
+
 function KVCache(T; head_dim, seq_length=0, n_kv_heads, batch_size=1)
     cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
     cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
-    return KVCache(head_dim, n_kv_heads, seq_length, batch_size, cache_k, cache_v)
+    return KVCache(cache_k, cache_v)
 end
 
-function config!(cache::KVCache; seq_length=cache.seq_length, batch_size=cache.batch_size)
-    cache.cache_k = similar(cache.cache_k, cache.head_dim, seq_length, cache.n_kv_heads, batch_size) .= 0
-    cache.cache_v = similar(cache.cache_v, cache.head_dim, seq_length, cache.n_kv_heads, batch_size) .= 0
+function config!(cache::KVCache; seq_length=seq_length(cache), batch_size=batch_size(cache))
+    cache.cache_k = similar(cache.cache_k, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
+    cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
 end
 
 clear!(cache::KVCache) = config!(cache, seq_length=0)
 
 function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray)
-    #if iszero(cache.seq_length)
-    #    return xk, xv
-    #else
+    if iszero(seq_length(cache))
+        return xk, xv
+    else
         seqlen = size(xk, 2)
         cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk
         cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv
         return cache.cache_k[:, 1:start_pos+seqlen, :, :],
             cache.cache_v[:, 1:start_pos+seqlen, :, :]
-    #end
+    end
 end
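# Typical lifecycle, as used by generate in src/sampling.jl (sketch):
#
#   config!(layer.attention.cache, seq_length = current_len + max_new_tokens)
#   ...                              # each forward pass calls update! with the running start_pos
#   clear!(layer.attention.cache)
#
# With seq_length == 0 (the default), update! passes xk and xv straight through, so an
# unconfigured cache behaves as if caching were disabled.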