diff --git a/Project.toml b/Project.toml index a9df5ce..9ef394d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "Jjama3" uuid = "1285d783-1a6d-4703-8f05-8ac83ef55592" authors = ["murrellb and contributors"] -version = "1.0.0-DEV" +version = "1.1.0-DEV" [deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18" @@ -19,13 +20,13 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" [sources] HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"} -LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"} LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"} [extensions] MetalExt = "Metal" [compat] +Accessors = "0.1.38" Distributions = "0.25" Flux = "0.14" LogitSamplers = "0.1" diff --git a/README.md b/README.md index 22e9fe3..1576d93 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ generate(model, prompt, - RoPE scaling (for exceeding the model's max training-time context length) is implemented, but likely incorrect with KV cache. Be careful if you're using with really long sequences. - Imported models are trainable (with Flux), including with low-rank (ie. LoRA) finetuning. - Sampling, training, etc compatible with CUDA, where everything is much faster. -- Metal acceleration for forward_inference, forward_loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU. +- Metal acceleration for forward inference, forward loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU. ## Samplers diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl index 5383601..85ce17d 100644 --- a/ext/MetalExt.jl +++ b/ext/MetalExt.jl @@ -1,7 +1,9 @@ module MetalExt -#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling. -#It seems that each single Metal call has some constant overhead that kills it. +# See https://github.com/FluxML/NNlib.jl/pull/614 + +# Note: Metal speeds things up a little for forward inference and forward_loss calls, but is VERY slow for sampling. +# It seems that each single Metal call has some constant overhead that kills it. 
using Metal, NNlib diff --git a/src/Jjama3.jl b/src/Jjama3.jl index f0db5b6..d985039 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,44 +1,53 @@ module Jjama3 -using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib -using LogitSamplers, LowRankLayers -import HuggingFaceTokenizers +using Flux +using SafeTensors +using Distributions +using LinearAlgebra +using StatsBase +using NNlib +using LogitSamplers +using LowRankLayers +using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained const tokenizer_from_file = HuggingFaceTokenizers.from_file -const Tokenizer = HuggingFaceTokenizers.Tokenizer - -const top_pk_sampler = LogitSamplers.top_pk_sampler -const argmax_sampler = LogitSamplers.argmax_sampler -const min_p_sampler = LogitSamplers.min_p_sampler -const top_nσ_sampler = LogitSamplers.top_nσ_sampler - +include("cache.jl") +export KVCache include("layers.jl") +export FeedForward +export RMSNorm +export RoPE +export Attention +export TransformerBlock +export Transformer + include("model.jl") -include("utils.jl") +export forward_loss +export forward_inference + include("sampling.jl") +export top_pk_sampler +export argmax_sampler +export top_nσ_sampler +export min_p_sampler +export generate +export tokenizer_from_repo +export tokenizer_from_file +export Tokenizer -export load_llama321B_from_safetensors, - load_llama3_from_safetensors, - generate, - forward_loss, - forward_inference, - top_pk_sampler, - argmax_sampler, - top_nσ_sampler, - min_p_sampler, - tokenizer_from_repo, - tokenizer_from_file, - encode, - decode, - Tokenizer, - llama3_instruct_prompt, - llama3_assistant_prompt, - smollm2_instruct_prompt, - smollm2_assistant_prompt, - structured_choice +include("utils.jl") +export encode +export decode +export load_llama321B_from_safetensors +export load_llama3_from_safetensors +export llama3_instruct_prompt +export llama3_assistant_prompt +export smollm2_instruct_prompt +export smollm2_assistant_prompt +export structured_choice end diff --git a/src/cache.jl b/src/cache.jl new file mode 100644 index 0000000..eda900e --- /dev/null +++ b/src/cache.jl @@ -0,0 +1,37 @@ +mutable struct KVCache{T,A<:AbstractArray{T,4}} + cache_k::A + cache_v::A +end + +Flux.@layer KVCache + +head_dim(cache::KVCache) = size(cache.cache_k, 1) +seq_length(cache::KVCache) = size(cache.cache_k, 2) +n_kv_heads(cache::KVCache) = size(cache.cache_k, 3) +batch_size(cache::KVCache) = size(cache.cache_k, 4) + +function KVCache(T; head_dim, seq_length=0, n_kv_heads, batch_size=1) + cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) + return KVCache(cache_k, cache_v) +end + +function config!(cache::KVCache; seq_length=seq_length(cache), batch_size=batch_size(cache)) + cache.cache_k = similar(cache.cache_k, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0 + cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0 +end + +clear!(cache::KVCache) = config!(cache, seq_length=0) + +function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) + if iszero(seq_length(cache)) + # cache not configured (seq_length == 0): skip caching and pass keys/values through + return xk, xv + else + seqlen = size(xk, 2) + cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk + cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv + return cache.cache_k[:, 1:start_pos+seqlen, :, :], + cache.cache_v[:, 1:start_pos+seqlen, :, :] + end +end diff --git
a/src/layers.jl b/src/layers.jl index 73fd654..6b14443 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -1,20 +1,14 @@ -struct KVCache{T} - cache_k::AbstractArray{T, 4} # (head_dim, seq_len, n_kv_heads, batch) - cache_v::AbstractArray{T, 4} -end +const AnyDense = Union{Dense, LoRADense} -function KVCache(T, batch_size::Int, seq_length::Int, n_kv_heads::Int, head_dim::Int; device = identity) - cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size) |> device - KVCache(cache_k, cache_v) -end -struct FeedForward - w1::Union{Dense, LoRADense} - w2::Union{Dense, LoRADense} - w3::Union{Dense, LoRADense} +struct FeedForward{W<:AnyDense} + w1::W + w2::W + w3::W end +Flux.@layer FeedForward + function FeedForward(dim::Int, ff_hidden_dim::Int) FeedForward( Dense(dim => ff_hidden_dim, bias=false), @@ -23,58 +17,116 @@ function FeedForward(dim::Int, ff_hidden_dim::Int) ) end -function (ff::FeedForward)(x) - return ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) -end +(ff::FeedForward)(x) = ff.w2(Flux.swish(ff.w1(x)) .* ff.w3(x)) -Flux.@layer :expand FeedForward -struct RMSNorm{T} - weight::AbstractVector{T} +struct RMSNorm{T<:AbstractFloat,W<:AbstractVector{T}} + weight::W eps::T end -function RMSNorm(dim::Int; eps::T=1f-5) where T - RMSNorm{T}(ones(T, dim), eps) -end +Flux.@layer RMSNorm + +RMSNorm(dim::Int; eps::T=1f-5) where T = RMSNorm(ones(T, dim), eps) function (norm::RMSNorm)(x) - rms = sqrt.(sum(abs2.(x), dims=1) ./ size(x,1) .+ norm.eps) + rms = sqrt.(sum(abs2, x, dims=1) / size(x, 1) .+ norm.eps) return x .* (norm.weight ./ rms) end -Flux.@layer RMSNorm -mutable struct Attention - wq::Union{Dense, LoRADense} - wk::Union{Dense, LoRADense} - wv::Union{Dense, LoRADense} - wo::Union{Dense, LoRADense} +struct RoPE{A<:AbstractArray} + cos::A + sin::A +end + +Flux.@layer RoPE + +Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:]) + +function apply_scaling!(freqs::AbstractVector; scale_factor=8) + #Hard-coded - should move these to the main model struct and grab them from the config. + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 + ### + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + for (i, freq) in enumerate(freqs) + wavelen = 2π / freq + if wavelen > low_freq_wavelen + freqs[i] = freq / scale_factor + elseif wavelen > high_freq_wavelen + @assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor) + freqs[i] = (1 - smooth) * freq / scale_factor + smooth * freq + end + end + return freqs +end + +function RoPE( + dim::Int, end_pos::Int; + theta::T=10000f0, use_scaled=true, scale_factor=8, +) where T + freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim)) + use_scaled && apply_scaling!(freqs; scale_factor) + freqs_complex = cis.(T.(0:end_pos-1) * freqs') + cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len) + sin = permutedims(imag(freqs_complex), (2, 1)) + cos = reshape(cos, (dim÷2, end_pos, 1, 1)) + sin = reshape(sin, (dim÷2, end_pos, 1, 1)) + return RoPE(cos, sin) +end + +# Note about Huggingface weights and rotary embeddings: +# https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 +# Use this one if you're using the Hugging Face weights. 
+function (rope::RoPE)(x) + head_dim = size(x, 1) + x1 = x[1:head_dim÷2, :, :, :] + x2 = x[head_dim÷2+1:end, :, :, :] + return vcat( + x1 .* rope.cos .- x2 .* rope.sin, + x2 .* rope.cos .+ x1 .* rope.sin + ) +end + + +struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache} + wq::Q + wk::K + wv::V + wo::O + dim::Int n_heads::Int n_kv_heads::Int head_dim::Int - n_rep::Int - cache::Union{Nothing, KVCache} + cache::C end +Flux.@layer Attention trainable=(wq,wv) + function Attention(dim::Int, n_heads::Int, n_kv_heads=n_heads; qkv_bias=false) head_dim = dim ÷ n_heads - n_rep = n_heads ÷ n_kv_heads Attention( Dense(dim => n_heads * head_dim, bias=qkv_bias), Dense(dim => n_kv_heads * head_dim, bias=qkv_bias), Dense(dim => n_kv_heads * head_dim, bias=qkv_bias), Dense(n_heads * head_dim => dim, bias=false), + dim, n_heads, n_kv_heads, head_dim, - n_rep, - nothing + KVCache(Float32; n_kv_heads, head_dim), ) end -function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=nothing) where T - dim, seqlen, batch = size(x) +repeat_kv(x::AbstractArray, n_rep::Int) = isone(n_rep) ? x : repeat(x, 1, n_rep, 1, 1) + +function (attn::Attention)(x::AbstractArray{T}, start_pos::Integer, rope=nothing, mask=false) where T + _, seqlen, batch = size(x) xq = attn.wq(x) xk = attn.wk(x) @@ -84,69 +136,48 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - #Lazy permute dims. Need to test CUDA. Note: test fails. - #xq = PermutedDimsArray(xq, (1,3,2,4)) - #xk = PermutedDimsArray(xk, (1,3,2,4)) - #xv = PermutedDimsArray(xv, (1,3,2,4)) - xq = permutedims(xq, (1,3,2,4)) xk = permutedims(xk, (1,3,2,4)) xv = permutedims(xv, (1,3,2,4)) - xq_rope = apply_rotary_emb(xq, freqs_cis) - xk_rope = apply_rotary_emb(xk, freqs_cis) - - if !isnothing(attn.cache) - xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) + if rope isa RoPE + xq, xk = rope(xq), rope(xk) end - xk_rope = repeat_kv(xk_rope, attn.n_rep) - xv = repeat_kv(xv, attn.n_rep) - - xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) - xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) + # Update if cache is configured with seq_length > 0 + xk, xv = update!(attn.cache, start_pos, xk, xv) + + xk = repeat_kv(xk, attn.n_heads ÷ attn.n_kv_heads) + xv = repeat_kv(xv, attn.n_heads ÷ attn.n_kv_heads) + + xq_for_attn = reshape(xq, attn.head_dim, :, attn.n_heads * batch) + xk_for_attn = reshape(xk, attn.head_dim, :, attn.n_heads * batch) xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - - #= - scores = batched_mul( - permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) - #batched_transpose(xq_for_attn), # (seqlen, head_dim, batch*heads) - xk_for_attn # (head_dim, seqlen, batch*heads) - ) ./ sqrt(T(attn.head_dim)) - if !isnothing(mask) - scores = scores .+ mask - end - sm_scores = softmax(scores; dims=2) - output = batched_mul(sm_scores, permutedims(xv_for_attn, (2,1,3))) - e_output = reshape(output, (seqlen, attn.head_dim, attn.n_heads, batch)) - p_output = permutedims(e_output, (2,3,1,4)) # (n_heads, head_dim, seqlen, batch) - =# - - scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) ./ sqrt(T(attn.head_dim)) - if !isnothing(mask) - scores = scores .+ mask - end + + scores = batched_mul(batched_transpose(xk_for_attn), xq_for_attn) / sqrt(T(attn.head_dim)) + scores = scores .+ mask
sm_scores = softmax(scores; dims=1) + output = batched_mul(xv_for_attn, sm_scores) e_output = reshape(output, (attn.head_dim, seqlen, attn.n_heads, batch)) p_output = permutedims(e_output, (1,3,2,4)) - - r_output = reshape(p_output, (attn.head_dim * attn.n_heads, seqlen, batch)) + r_output = reshape(p_output, (attn.n_heads * attn.head_dim, seqlen, batch)) proj = attn.wo(r_output) return proj end -Flux.@layer :expand Attention trainable=(wq, wv) -struct TransformerBlock - attention::Attention - feed_forward::FeedForward - attention_norm::RMSNorm - ffn_norm::RMSNorm +struct TransformerBlock{A<:Attention,F<:FeedForward,AN<:RMSNorm,FN<:RMSNorm} + attention::A + feed_forward::F + attention_norm::AN + ffn_norm::FN end -function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hidden_dim = 4 * dim; - norm_eps=1f-5, qkv_bias=false) +function TransformerBlock( + dim::Int, n_heads::Int, n_kv_heads::Int = n_heads, ff_hidden_dim = 4 * dim; + norm_eps=1f-5, qkv_bias=false, +) TransformerBlock( Attention(dim, n_heads, n_kv_heads; qkv_bias), FeedForward(dim, ff_hidden_dim), @@ -155,42 +186,38 @@ function TransformerBlock(dim::Int, n_heads::Int, n_kv_heads::Int=n_heads, ff_hi ) end -function (block::TransformerBlock)(x, start_pos, freqs_cis, mask=nothing) - h = x + block.attention(block.attention_norm(x), start_pos, freqs_cis, mask) +function (block::TransformerBlock)(x, start_pos, rope, mask=nothing) + h = x + block.attention(block.attention_norm(x), start_pos, rope, mask) out = h + block.feed_forward(block.ffn_norm(h)) return out end -Flux.@layer TransformerBlock trainable=(attention, ) +Flux.@layer TransformerBlock trainable=(attention,) + -struct Transformer{T} - tok_embeddings::Flux.Embedding - layers::AbstractVector{TransformerBlock} - norm::RMSNorm{T} - output::Dense - freqs_cis::Tuple{AbstractArray{T, 4}, AbstractArray{T, 4}} +struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE} + tok_embeddings::E + layers::B + norm::N + output::O + rope::R end -function Transformer(vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, - n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; - norm_eps::T=1f-5, - qkv_bias=false, - rope_theta::T=500000f0, - use_scaled_rope=false, - scale_factor=8) where T - +Flux.@layer Transformer trainable=(layers,) + +function Transformer( + vocab_size::Int, dim::Int, n_layers::Int, n_heads::Int, + n_kv_heads::Int, max_seq_len::Int, ff_hidden_dim::Int; + norm_eps::T=1f-5, + qkv_bias=false, + rope_theta::T=500000f0, + use_scaled_rope=false, + scale_factor=8, +) where T tok_embeddings = Flux.Embedding(vocab_size => dim) - layers = [TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers] + layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers) norm = RMSNorm(dim, eps=norm_eps) output = Dense(dim => vocab_size, bias=false) - freqs_cis = precompute_freqs_cis( - dim ÷ n_heads, - max_seq_len * 2; - theta=rope_theta, - use_scaled=use_scaled_rope, - scale_factor=scale_factor - ) - Transformer(tok_embeddings, layers, norm, output, freqs_cis) + rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor) + Transformer(tok_embeddings, layers, norm, output, rope) end - -Flux.@layer :expand Transformer trainable=(layers, ) diff --git a/src/model.jl b/src/model.jl index 9fdb3c4..179621d 100644 --- a/src/model.jl +++ b/src/model.jl @@ 
-1,117 +1,35 @@ #Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 -function apply_scaling(freqs::AbstractVector; scale_factor=8) - #Hard-coded - I should move these to the main model struct and grab them from the config. - low_freq_factor = 1 - high_freq_factor = 4 - old_context_len = 8192 - ### - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = similar(freqs) - for (i, freq) in enumerate(freqs) - wavelen = 2 * π / freq - if wavelen < high_freq_wavelen - new_freqs[i] = freq - elseif wavelen > low_freq_wavelen - new_freqs[i] = freq / scale_factor - else - @assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / - (high_freq_factor - low_freq_factor) - new_freqs[i] = (1 - smooth) * freq / scale_factor + smooth * freq - end - end - return new_freqs -end - -function precompute_freqs_cis(dim::Int, end_pos::Int; - theta::T=10000f0, use_scaled=true, scale_factor=8) where T - freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim)) - if use_scaled - freqs = apply_scaling(freqs; scale_factor=scale_factor) - end - freqs_complex = cis.(T.(0:end_pos-1) * freqs') - cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len) - sin = permutedims(imag(freqs_complex), (2, 1)) - cos = reshape(cos, (dim÷2, end_pos, 1, 1)) - sin = reshape(sin, (dim÷2, end_pos, 1, 1)) - return cos, sin -end - - -#Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 -#Use this one if you're using the Hugging Face weights. -function apply_rotary_emb(x, freqs_cis) - head_dim, seq_len, n_heads, batch = size(x) - x1 = @view x[1:head_dim÷2, :, :, :] - x2 = @view x[head_dim÷2+1:end, :, :, :] - cos, sin = freqs_cis - out = vcat( - x1 .* cos .- x2 .* sin, - x2 .* cos .+ x1 .* sin - ) - return out +function create_mask(h::AbstractArray{T}) where T<:AbstractFloat + dim, seqlen, batch = size(h) + mask = similar(h, seqlen, seqlen) + mask .= T(-Inf) + #mask = triu(mask, 1) + mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup + return mask end -function update_kv_cache(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray) - seqlen = size(xk, 2) - cache.cache_k[:, (start_pos+1):(start_pos+seqlen), :, :] .= xk - cache.cache_v[:, (start_pos+1):(start_pos+seqlen), :, :] .= xv - return cache.cache_k[:, 1:(start_pos+seqlen), :, :], - cache.cache_v[:, 1:(start_pos+seqlen), :, :] -end - -function repeat_kv(x::AbstractArray, n_rep::Int) - if n_rep == 1 - return x - end - return repeat(x, 1, n_rep, 1, 1) -end - -function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, start_pos::Int) where T - seqlen = size(tokens, 1) # tokens expected as (seq_len, batch) +function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int=0) h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch) - - # Get relevant freqs_cis slice - cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) - freqs_cis = (cos[:,start_pos+1:start_pos+seqlen,:,:], sin[:,start_pos+1:start_pos+seqlen,:,:]) - - + rope = model.rope[start_pos+1:start_pos+size(tokens, 1)] mask = create_mask(h) for layer in model.layers - h = layer(h, start_pos, freqs_cis, mask) + h = layer(h, start_pos, rope, mask) end h = model.norm(h) output = model.output(h) return output end -function create_mask(h::AbstractArray) - 
Flux.Zygote.ignore() do - embeddim, seqlen, batch = size(h) - mask = similar(h, seqlen, seqlen) - T = eltype(h) - mask .= T(-Inf) - #mask = triu(mask, 1) - mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup - return mask - end -end - -function forward_loss(model::Transformer{T}, inputs::AbstractArray, +function forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; ignore_index::Int=-100, - mask = :auto) where T + mask = :auto) seqlen = size(inputs, 1) #(seq_len, batch) h = model.tok_embeddings(inputs) # (dim, seq_len, batch) - cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) - freqs_cis = (cos[:,1:seqlen,:,:], sin[:,1:seqlen,:,:]) - # Forward through layers (start_pos = 0 disables KV caching) - if mask == :auto - mask = create_mask(h) - end + rope = model.rope[1:seqlen] + mask = create_mask(h) for layer in model.layers - h = layer(h, 0, freqs_cis, mask) + h = layer(h, 0, rope, mask) end h = model.norm(h) logits = model.output(h) @@ -137,70 +55,5 @@ function forward_loss(model::Transformer{T}, inputs::AbstractArray, return loss end - -#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 -#= -#Use this one if you're using the original Meta weights. -#You'll need to change the type of the freqs_cis field in Transformer to match. -function precompute_freqs_cis(dim::Int, end_pos::Int; - theta::Float32=10000f0, - use_scaled::Bool=true, scale_factor::Int=8) - # Create frequencies for the first half of dimensions - freqs = 1f0 ./ (theta .^ (Float32.(0:2:dim-1)[1:dim÷2] ./ dim)) - # Create position indices - note, using 0 indexing here because python consistency. Not sure if it makes a difference. - t = Float32.(0:end_pos-1) - if use_scaled - freqs = apply_scaling(freqs; scale_factor=scale_factor) - end - # Compute outer product - freqs = t * freqs' - # Convert to complex exponentials - # Note: Julia's cis(x) = exp(ix) = cos(x) + i*sin(x) - freqs_complex = cis.(freqs) - # Stack real and imaginary parts - # Note: Julia's reshape is similar to PyTorch's stack - freqs_cis_real = reshape( - reinterpret(Float32, reshape(freqs_complex, :)), - (2, size(freqs)...) - ) - # Permute to match PyTorch's dimension ordering - return permutedims(freqs_cis_real, (2,3,1)) -end - -function apply_rotary_emb(x, freqs_cis) - # x is (head_dim, seq_len, n_heads, batch) in Julia - # freqs_cis is (seq_len, head_dim/2, 2) - - #@show size(freqs_cis) - - # Reshape x to separate real/imaginary pairs - head_dim, seq_len, n_heads, batch = size(x) - x_reshaped = reshape(x, (2, head_dim÷2, seq_len, n_heads, batch)) - - # Reshape freqs_cis to broadcast correctly - # Note: reshape to (2, head_dim/2, seq_len, 1, 1) for broadcasting - freqs_cis = permutedims(freqs_cis, (3, 2, 1)) # now (2, head_dim/2, seq_len) - freqs_cis = reshape(freqs_cis, (2, size(freqs_cis, 2), size(freqs_cis, 3), 1, 1)) - - # Apply rotation using complex multiplication formula: - # (a + bi)(c + di) = (ac-bd) + (ad+bc)i - x_real = x_reshaped[1:1, :, :, :, :] - x_imag = x_reshaped[2:2, :, :, :, :] - f_real = freqs_cis[1:1, :, :, :, :] - f_imag = freqs_cis[2:2, :, :, :, :] - - #@show size(f_real) - #@show size(f_imag) - - #This is for checking the freqs_cis. 
- #Note: the cos, sin values are repeated in python - #g(f_real, f_imag) #passes - - out_real = x_real .* f_real .- x_imag .* f_imag - out_imag = x_imag .* f_real .+ x_real .* f_imag - - # Combine and reshape back - out = vcat(out_real, out_imag) - return reshape(out, (head_dim, seq_len, n_heads, batch)) -end -=# +# compat +forward_inference(model, args...) = model(args...) diff --git a/src/sampling.jl b/src/sampling.jl index ee51d3d..50abebe 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -3,44 +3,35 @@ # Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time. """ generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) - + Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now. Runs on CPU by default. If the model is on the GPU (assuming Flux.jl, eg. `model = gpu(model)`), then pass `device = gpu` to `generate` to run on the GPU. - tkn = llama3_tokenizer() - generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) +```julia +tkn = llama3_tokenizer() +generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010) +``` """ -function generate(model::Transformer{T}, - initial_tokens::AbstractArray{IntT}; - max_new_tokens=100, - sampler::Function=argmax_sampler, - tokenizer_for_printing = nothing, - end_token = 128010, - device = identity) where {T, IntT} - - # Initialize sequence with a new copy of the tokens +function generate( + model::Transformer{T}, + initial_tokens::AbstractArray{<:Integer}; + max_new_tokens=100, + sampler::Function=argmax_sampler, + tokenizer_for_printing = nothing, + end_token = 128010, + device = identity +) where T current_len = length(initial_tokens) - tokens = Vector{IntT}(undef, current_len + max_new_tokens) - tokens[1:current_len] = initial_tokens - # Set up KV caches for all attention layers + tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens)) + for layer in model.layers - layer.attention.cache = KVCache( - T, # eltype - 1, # batch_size - current_len + max_new_tokens, # max possible sequence length - layer.attention.n_kv_heads, - layer.attention.head_dim, - device = device - ) - end - # Process the initial sequence - if current_len > 0 - input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) - logits = forward_inference(model, input_tokens, 0) - start_pos = current_len - else - start_pos = 0 + config!(layer.attention.cache, seq_length = current_len + max_new_tokens) end + + input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) + logits = model(input_tokens, 0) + start_pos = current_len + # Generate new tokens one at a time for _ in 1:max_new_tokens # If sequence is empty or we want to process just the last token @@ -50,22 +41,18 @@ function generate(model::Transformer{T}, input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token end # Get logits for next token - logits = forward_inference(model, input_tokens, start_pos) + logits = model(input_tokens, start_pos) # Sample next token (logits are size vocab × 1 × 1) next_token = sampler(logits[:, end, 1]) current_len += 1 tokens[current_len] = next_token - if !isnothing(tokenizer_for_printing) - print(decode(tokenizer_for_printing, [next_token])) - end - if 
next_token == end_token - break - end + !isnothing(tokenizer_for_printing) && print(decode(tokenizer_for_printing, [next_token])) + next_token == end_token && break start_pos += 1 end # Clear KV caches for layer in model.layers - layer.attention.cache = nothing + clear!(layer.attention.cache) end return tokens[1:current_len] end diff --git a/src/utils.jl b/src/utils.jl index c8758e3..65e7b5c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,5 @@ +using Accessors + encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1 decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...) @@ -41,7 +43,10 @@ so if you're loading weights from a different source, you might get very poor mo model_weight_paths = ["Llama3_2_1B_instruct/model.safetensors"] #Can be an array of paths if the model is split across multiple files model = load_llama3_from_safetensors(model_weight_paths, config) """ -function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32, add_lora_to = Symbol[], lora_dim = 0) +function load_llama3_from_safetensors( + paths::Vector{String}, config; + T = Float32, add_lora_to = Symbol[], lora_dim = 0, +) config = Dict(config) #Just in case the user passed eg. a JSON3.Object #@assert config[:rope_scaling][:rope_type] == "llama3" #@assert config[:rope_scaling][:low_freq_factor] == 1 @@ -144,40 +149,39 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 end if !isempty(add_lora_to) - #Then load in the current layers: if :Q in add_lora_to for layer in model.layers - layer.attention.wq = LoRADense(layer.attention.wq, lora_dim) + @reset layer.attention.wq = LoRADense(layer.attention.wq, lora_dim) end end if :K in add_lora_to for layer in model.layers - layer.attention.wk = LoRADense(layer.attention.wk, lora_dim) + @reset layer.attention.wk = LoRADense(layer.attention.wk, lora_dim) end end if :V in add_lora_to for layer in model.layers - layer.attention.wv = LoRADense(layer.attention.wv, lora_dim) + @reset layer.attention.wv = LoRADense(layer.attention.wv, lora_dim) end end if :O in add_lora_to for layer in model.layers - layer.attention.wo = LoRADense(layer.attention.wo, lora_dim) + @reset layer.attention.wo = LoRADense(layer.attention.wo, lora_dim) end end if :w1 in add_lora_to for layer in model.layers - layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim) + @reset layer.feed_forward.w1 = LoRADense(layer.feed_forward.w1, lora_dim) end end if :w2 in add_lora_to for layer in model.layers - layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim) + @reset layer.feed_forward.w2 = LoRADense(layer.feed_forward.w2, lora_dim) end end if :w3 in add_lora_to for layer in model.layers - layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim) + @reset layer.feed_forward.w3 = LoRADense(layer.feed_forward.w3, lora_dim) end end end
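
A minimal usage sketch of the reworked API in this diff: the direct `model(tokens, start_pos)` call that replaces `forward_inference`, the per-layer KV-cache lifecycle (`config!`, writes during the forward pass, `clear!`), and `generate`. The model sizes, token ids, and cache head-room below are arbitrary placeholders rather than values from this patch; `config!` and `clear!` are unexported, hence the `Jjama3.` prefix.

```julia
using Jjama3

# Toy Transformer, just to exercise the new call path; all sizes are placeholders:
# vocab_size=1000, dim=64, n_layers=2, n_heads=8, n_kv_heads=4, max_seq_len=256, ff_hidden_dim=256.
model = Transformer(1000, 64, 2, 8, 4, 256, 256)

tokens = [1, 5, 42]              # placeholder 1-based token ids, as `encode` would return
prompt = reshape(tokens, :, 1)   # (seq_len, batch=1)

# Full-sequence forward pass. This replaces `forward_inference(model, prompt, 0)`,
# which is kept as a thin compat wrapper.
logits = model(prompt, 0)

# Incremental decoding: size each layer's KV cache up front, prefill, then feed one token at a time.
for layer in model.layers
    Jjama3.config!(layer.attention.cache, seq_length = length(tokens) + 16)
end
logits = model(prompt, 0)                           # prefill; cache now holds positions 1:3
logits = model(reshape([7], :, 1), length(tokens))  # one new token at start_pos = 3

# A zero-length cache disables caching again for plain full-sequence calls.
for layer in model.layers
    Jjama3.clear!(layer.attention.cache)
end

# `generate` wraps the same cache setup, prefill, sampling loop, and cache cleanup.
out = generate(model, tokens; max_new_tokens=10, sampler=top_pk_sampler(p=0.5f0, k=5))
```

Since `generate` configures and clears the caches itself, the manual `config!`/`clear!` calls are only needed when driving `model(tokens, start_pos)` directly.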