diff --git a/Project.toml b/Project.toml index b833e42..76755b9 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,15 @@ BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + +[weakdeps] +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + +[extensions] +MetalExt = "Metal" [compat] julia = "1.9" diff --git a/ext/MetalExt.jl b/ext/MetalExt.jl new file mode 100644 index 0000000..8659a47 --- /dev/null +++ b/ext/MetalExt.jl @@ -0,0 +1,18 @@ +module MetalExt + +#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling. +#It seems that each single Metal call has some constant overhead that kills it. + +using Metal, Jjama3.NNlib + +function NNlib.batched_mul(a::MtlArray, b::MtlArray) + a_shape = size(a) + b_shape = size(b) + a_reshaped = reshape(a, a_shape[1], a_shape[2], :) + b_reshaped = reshape(b, b_shape[1], b_shape[2], :) + res = Metal.zeros(a_shape[1], b_shape[2], size(a_reshaped)[3]) + Metal.MPS.matmul!(res, a_reshaped,b_reshaped) + return reshape(res, a_shape[1], b_shape[2], a_shape[3:end]...) +end + +end diff --git a/src/Jjama3.jl b/src/Jjama3.jl index a59ac18..ac0cf52 100644 --- a/src/Jjama3.jl +++ b/src/Jjama3.jl @@ -1,11 +1,11 @@ module Jjama3 -using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra +using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib include("model.jl") include("utils.jl") include("sampling.jl") -export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference +export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler end diff --git a/src/model.jl b/src/model.jl index 6637282..1078808 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,9 +1,11 @@ -#Important about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 +#Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172 function apply_scaling(freqs::AbstractVector; scale_factor=8) + #Hard-coded - I should move these to the main model struct and grab them from the config. low_freq_factor = 1 high_freq_factor = 4 - old_context_len = 8192 # original llama3 length + old_context_len = 8192 + ### low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor new_freqs = similar(freqs) @@ -38,7 +40,7 @@ function precompute_freqs_cis(dim::Int, end_pos::Int; end -#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 +#Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 #Use this one if you're using the Hugging Face weights. function apply_rotary_emb(x, freqs_cis) # x is (head_dim, seq_len, n_heads, batch) @@ -118,7 +120,6 @@ function repeat_kv(x::AbstractArray, n_rep::Int) head_dim, seq_len, n_kv_heads, batch = size(x) x_expanded = reshape(x, (head_dim, seq_len, 1, n_kv_heads, batch)) x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) - x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1) return reshape(x_repeated, (head_dim, seq_len, n_rep * n_kv_heads, batch)) end @@ -167,7 +168,7 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch)) - #Can we switch to keeping this in its shape, aplpying rot emb in that shape, and only permuting one + #Some GPUs don't like PermutedDimsArray #xq = PermutedDimsArray(xq, (1,3,2,4)) #No idea if this is faster... #xk = PermutedDimsArray(xk, (1,3,2,4)) #xv = PermutedDimsArray(xv, (1,3,2,4)) @@ -175,29 +176,26 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask= xk = permutedims(xk, (1,3,2,4)) xv = permutedims(xv, (1,3,2,4)) - # Apply RoPE xq_rope = apply_rotary_emb(xq, freqs_cis) xk_rope = apply_rotary_emb(xk, freqs_cis) - # Handle KV cache + if !isnothing(attn.cache) xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv) end + # Apply GQA via repeat_kv xk_rope = repeat_kv(xk_rope, attn.n_rep) xv = repeat_kv(xv, attn.n_rep) - # Reshape for attention - dummy dim is seqlength, which isn't the length of the seq when using the KV cache + xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch) xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch) xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch) - # Compute attention scores scores = batched_mul( 0f0 .+ permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads) 0f0 .+xk_for_attn # (head_dim, seqlen, batch*heads) ) ./ sqrt(T(attn.head_dim)) if !isnothing(mask) - #@show typeof(scores) - #@show typeof(mask) scores = scores .+ mask end #len: 3, len: 3, headsxbatch: 8 @@ -298,28 +296,34 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st return output end +function create_mask(h::AbstractArray) + embeddim, seqlen, batch = size(h) + mask = similar(h, seqlen, seqlen) + T = eltype(h) + mask .= T(-Inf) + mask = triu(mask, 1) + return mask +end + function forward_loss(model::Transformer{T}, inputs::AbstractArray, targets::AbstractArray; ignore_index::Int=-100, - mask = triu(fill(T(-Inf), (size(inputs, 1), size(inputs, 1))),1)) where T + mask = :auto) where T seqlen = size(inputs, 1) #(seq_len, batch) h = model.tok_embeddings(inputs) # (dim, seq_len, batch) - cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1) freqs_cis = (cos[:,1:seqlen,:,:], sin[:,1:seqlen,:,:]) - # Forward through layers (start_pos = 0 disables KV caching) + if mask == :auto + mask = create_mask(h) + end for layer in model.layers h = layer(h, 0, freqs_cis, mask) end h = model.norm(h) logits = model.output(h) - #@show size(logits) # Need to reshape to (vocab_size, seq_len * batch) logits_2d = reshape(logits, size(logits,1), :) - #@show [argmax(logits_2d[:,i]) for i in 1:size(logits_2d,2)] - #@show size(logits_2d) targets_1d = reshape(targets, :) - #@show size(targets_1d) # Mask out ignored indices - will handle this later. # Note: this is not the autoregressive mask, but the mask for the loss function #= @@ -335,18 +339,11 @@ function forward_loss(model::Transformer{T}, inputs::AbstractArray, =# vocab_size = size(model.tok_embeddings.weight, 2) gt = Flux.onehotbatch(targets_1d, 1:vocab_size) - #@show size(gt) loss = Flux.logitcrossentropy(logits_2d, gt) - #@show Flux.logitcrossentropy(logits_2d, gt, agg = identity) return loss end - - - - - #https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509 #= #Use this one if you're using the original Meta weights. diff --git a/src/sampling.jl b/src/sampling.jl index 6a13e5c..9833da3 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -1,12 +1,39 @@ -function default_sampler(logits::AbstractVector) - return argmax(logits) +function argmax_sampler(logits::AbstractVector; device = identity) + return argmax(device(logits)) end +argmax_sampler(; device = identity) = logits -> argmax_sampler(logits; device = device) + +function top_pk_sampler(logits::AbstractVector; p = 0.5f0, k = 5, device = identity) + probs = device(Jjama3.softmax(logits)) + perm = partialsortperm(probs, 1:k, rev=true) + sorted_probs = probs[perm] + cumsum_probs = cumsum(sorted_probs) + if cumsum_probs[1] > p + return perm[1] + else + cutoff = findlast(cumsum_probs .< p) + return sample(perm[1:cutoff], Weights(sorted_probs[1:cutoff])) + end +end + +top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device) + +# This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence +# to the sampling of each token. But when I try and fix it, the model gets slightly dumber. +# Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time. +""" +Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now. +Runs on CPU by default. If the model is on the GPU (assuming Flux.jl, eg. `model = gpu(model)`), then pass `device = gpu` to `generate` to run on the GPU. + + tkn = llama3_tokenizer() + generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010) +""" function generate(model::Transformer{T}, initial_tokens::AbstractArray{IntT}; max_new_tokens=100, - sampler::Function=default_sampler, + sampler::Function=argmax_sampler, encoder_for_printing = nothing, end_token = 128010, device = identity) where {T, IntT} @@ -28,7 +55,7 @@ function generate(model::Transformer{T}, end # Process the initial sequence if current_len > 0 - input_tokens = reshape(initial_tokens, :, 1) # (seq_len, batch=1) + input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1) logits = forward_inference(model, input_tokens, 0) start_pos = current_len else @@ -38,14 +65,14 @@ function generate(model::Transformer{T}, for _ in 1:max_new_tokens # If sequence is empty or we want to process just the last token if start_pos == 0 - input_tokens = reshape([128001], :, 1) # Use start of text token if empty + input_tokens = device(reshape([128001], :, 1)) # Use start of text token if empty else - input_tokens = reshape([tokens[current_len]], :, 1) # Just the last token + input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token end # Get logits for next token logits = forward_inference(model, input_tokens, start_pos) # Sample next token (logits are size vocab × 1 × 1) - next_token = sampler(vec(logits[:, end, 1])) + next_token = sampler(logits[:, end, 1]) current_len += 1 tokens[current_len] = next_token if !isnothing(encoder_for_printing) diff --git a/src/utils.jl b/src/utils.jl index f93046f..aec8cc1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,6 +10,7 @@ Format a prompt for use with Llama3.2's instruction format, with a simple "You a assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a helpful assistant\n", prompt, tkn); +#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/ """ Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles. @@ -17,9 +18,8 @@ Format a prompt for use with Llama3.2's instruction format, injecting the system prompt = format_llama32_instruction_prompt("\\nYou are a helpful assistant\\n", "What is the capital of France?", tkn) generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn) """ -function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) - #begin_of_text, start_header_id - prompt = [128001, 128007] #plus 1 because Julia is 1-indexed +function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) + prompt = [128001, 128007] #begin_of_text, start_header_id prompt = vcat(prompt, tokenizer.encode("system")) push!(prompt, 128008) #end_header_id prompt = vcat(prompt, tokenizer.encode(sys_prompt)) @@ -28,12 +28,25 @@ function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer) push!(prompt, 128008) #end_header_id prompt = vcat(prompt, tokenizer.encode("\n")) prompt = vcat(prompt, tokenizer.encode(user_prompt)) - prompt = vcat(prompt, [128009, 128007]) #eot_id, start_header_id + prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id prompt = vcat(prompt, tokenizer.encode("assistant")) push!(prompt, 128008) #end_header_id return prompt end +#These have already been incremented by 1 to account for Julia's 1-indexing +special_tokens = Dict( + "<|begin_of_text|>" => 128001, + "<|end_of_text|>" => 128002, + "<|start_header_id|>" => 128007, + "<|end_header_id|>" => 128008, + "<|eot_id|>" => 128010, + "<|finetune_right_pad_id|>" => 128005, + "<|python_tag|>" => 128011 +) + +#[ "<|start_header_id|>user<|end_header_id|>\n\nGiven the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation?\nA. Planetary density will decrease.\nB. Planetary years will become longer.\nC. Planetary days will become shorter.\nD. Planetary gravity will become stronger.\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe best answer is" ] + """ Load a Llama3 model from a set of Huggingface safetensors files, and the config.json file. @@ -131,3 +144,4 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32 end load_llama3_from_safetensors(path::String, config; T = Float32) = load_llama3_from_safetensors([path], config; T = T) +