
Commit

Forward pass working on Metal. Sampling slow though.
murrellb committed Nov 14, 2024
1 parent beebe36 commit 56c6452
Showing 6 changed files with 102 additions and 38 deletions.
8 changes: 8 additions & 0 deletions Project.toml
@@ -8,7 +8,15 @@ BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[extensions]
MetalExt = "Metal"

[compat]
julia = "1.9"
18 changes: 18 additions & 0 deletions ext/MetalExt.jl
@@ -0,0 +1,18 @@
module MetalExt

#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling.
#It seems that each single Metal call has some constant overhead that kills it.

using Metal, Jjama3.NNlib

function NNlib.batched_mul(a::MtlArray, b::MtlArray)
a_shape = size(a)
b_shape = size(b)
#Collapse any trailing batch dims into a single third dim so MPS sees 3D inputs.
a_reshaped = reshape(a, a_shape[1], a_shape[2], :)
b_reshaped = reshape(b, b_shape[1], b_shape[2], :)
res = Metal.zeros(a_shape[1], b_shape[2], size(a_reshaped)[3])
Metal.MPS.matmul!(res, a_reshaped, b_reshaped)
#Restore the original trailing batch dims on the result.
return reshape(res, a_shape[1], b_shape[2], a_shape[3:end]...)
end

end
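A minimal usage sketch, not part of this commit: with Metal.jl installed alongside Jjama3, loading both activates the MetalExt extension, so `NNlib.batched_mul` on `MtlArray`s dispatches to the MPS path above (array names and sizes below are made up for illustration).

using Jjama3, Metal, NNlib
a = MtlArray(rand(Float32, 4, 8, 16))   # (m, k, batch)
b = MtlArray(rand(Float32, 8, 5, 16))   # (k, n, batch)
c = NNlib.batched_mul(a, b)             # (4, 5, 16), computed via Metal.MPS.matmul!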
4 changes: 2 additions & 2 deletions src/Jjama3.jl
@@ -1,11 +1,11 @@
module Jjama3

using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra
using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib

include("model.jl")
include("utils.jl")
include("sampling.jl")

export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference
export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler

end
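For orientation, a rough end-to-end sketch of the exported API (not part of the diff; the file name is hypothetical and `config` stands for the parsed config.json in whatever form `load_llama3_from_safetensors` expects):

model  = load_llama3_from_safetensors("model.safetensors", config)
tkn    = llama3_tokenizer()
prompt = assistant_prompt("What is the capital of France?", tkn)
generate(model, prompt; max_new_tokens = 100, sampler = top_pk_sampler(p = 0.5f0, k = 5), encoder_for_printing = tkn)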
47 changes: 22 additions & 25 deletions src/model.jl
@@ -1,9 +1,11 @@
#Important about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172
#Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172

function apply_scaling(freqs::AbstractVector; scale_factor=8)
#Hard-coded - I should move these to the main model struct and grab them from the config.
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
old_context_len = 8192
###
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = similar(freqs)
@@ -38,7 +40,7 @@ function precompute_freqs_cis(dim::Int, end_pos::Int;
end


#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509
#Note about Huggingface weights and rotary embeddings: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509
#Use this one if you're using the Hugging Face weights.
function apply_rotary_emb(x, freqs_cis)
# x is (head_dim, seq_len, n_heads, batch)
@@ -118,7 +120,6 @@ function repeat_kv(x::AbstractArray, n_rep::Int)
head_dim, seq_len, n_kv_heads, batch = size(x)
x_expanded = reshape(x, (head_dim, seq_len, 1, n_kv_heads, batch))
x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1)
x_repeated = repeat(x_expanded, 1, 1, n_rep, 1, 1)
return reshape(x_repeated, (head_dim, seq_len, n_rep * n_kv_heads, batch))
end
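#Shape-only illustration of repeat_kv (not in the diff): each KV head is replicated n_rep times
#so grouped-query attention can match the keys/values against every query head.
x = rand(Float32, 64, 10, 2, 1)   # (head_dim, seq_len, n_kv_heads = 2, batch)
y = repeat_kv(x, 4)               # n_rep = 4
size(y)                           # (64, 10, 8, 1): n_rep * n_kv_heads = 8 heads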

@@ -167,37 +168,34 @@ function (attn::Attention)(x::AbstractArray{T}, start_pos::Int, freqs_cis, mask=
xk = reshape(xk, (attn.head_dim, attn.n_kv_heads, seqlen, batch))
xv = reshape(xv, (attn.head_dim, attn.n_kv_heads, seqlen, batch))

#Can we switch to keeping this in its shape, applying rot emb in that shape, and only permuting once?
#Some GPUs don't like PermutedDimsArray
#xq = PermutedDimsArray(xq, (1,3,2,4)) #No idea if this is faster...
#xk = PermutedDimsArray(xk, (1,3,2,4))
#xv = PermutedDimsArray(xv, (1,3,2,4))
xq = permutedims(xq, (1,3,2,4))
xk = permutedims(xk, (1,3,2,4))
xv = permutedims(xv, (1,3,2,4))

# Apply RoPE
xq_rope = apply_rotary_emb(xq, freqs_cis)
xk_rope = apply_rotary_emb(xk, freqs_cis)
# Handle KV cache

if !isnothing(attn.cache)
xk_rope, xv = update_kv_cache(attn.cache, start_pos, xk_rope, xv)
end

# Apply GQA via repeat_kv
xk_rope = repeat_kv(xk_rope, attn.n_rep)
xv = repeat_kv(xv, attn.n_rep)
# Reshape for attention - the `:` dim is the sequence length, which for the keys/values is the cached length rather than the current input length when the KV cache is in use

xq_for_attn = reshape(xq_rope, attn.head_dim, :, attn.n_heads * batch)
xk_for_attn = reshape(xk_rope, attn.head_dim, :, attn.n_heads * batch)
xv_for_attn = reshape(xv, attn.head_dim, :, attn.n_heads * batch)
# Compute attention scores

scores = batched_mul(
0f0 .+ permutedims(xq_for_attn, (2,1,3)), # (seqlen, head_dim, batch*heads)
0f0 .+ xk_for_attn # (head_dim, seqlen, batch*heads)
) ./ sqrt(T(attn.head_dim))
if !isnothing(mask)
#@show typeof(scores)
#@show typeof(mask)
scores = scores .+ mask
end
#len: 3, len: 3, headsxbatch: 8
@@ -298,28 +296,34 @@ function forward_inference(model::Transformer{T}, tokens::AbstractArray{Int}, st
return output
end

function create_mask(h::AbstractArray)
embeddim, seqlen, batch = size(h)
mask = similar(h, seqlen, seqlen)
T = eltype(h)
mask .= T(-Inf)
mask = triu(mask, 1)
return mask
end
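#Illustration (not in the diff): for seqlen = 3 the mask is strictly upper-triangular -Inf,
#so adding it to the scores lets each position attend only to itself and earlier positions:
#  0.0  -Inf  -Inf
#  0.0   0.0  -Inf
#  0.0   0.0   0.0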

function forward_loss(model::Transformer{T}, inputs::AbstractArray,
targets::AbstractArray; ignore_index::Int=-100,
mask = triu(fill(T(-Inf), (size(inputs, 1), size(inputs, 1))),1)) where T
mask = :auto) where T
seqlen = size(inputs, 1) #(seq_len, batch)
h = model.tok_embeddings(inputs) # (dim, seq_len, batch)

cos, sin = model.freqs_cis #@show size(cos) #(head_dim/2, max_RoPE, 1, 1)
freqs_cis = (cos[:,1:seqlen,:,:], sin[:,1:seqlen,:,:])

# Forward through layers (start_pos = 0 disables KV caching)
if mask == :auto
mask = create_mask(h)
end
for layer in model.layers
h = layer(h, 0, freqs_cis, mask)
end
h = model.norm(h)
logits = model.output(h)
#@show size(logits)
# Need to reshape to (vocab_size, seq_len * batch)
logits_2d = reshape(logits, size(logits,1), :)
#@show [argmax(logits_2d[:,i]) for i in 1:size(logits_2d,2)]
#@show size(logits_2d)
targets_1d = reshape(targets, :)
#@show size(targets_1d)
# Mask out ignored indices - will handle this later.
# Note: this is not the autoregressive mask, but the mask for the loss function
#=
@@ -335,18 +339,11 @@ function forward_loss(model::Transformer{T}, inputs::AbstractArray,
=#
vocab_size = size(model.tok_embeddings.weight, 2)
gt = Flux.onehotbatch(targets_1d, 1:vocab_size)
#@show size(gt)
loss = Flux.logitcrossentropy(logits_2d, gt)
#@show Flux.logitcrossentropy(logits_2d, gt, agg = identity)
return loss
end
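#Hypothetical training-style call (not in the diff), assuming `model` is a loaded Transformer:
#forward_loss expects inputs and targets of shape (seq_len, batch), so a standard next-token
#setup shifts the same token vector by one position.
tkn  = llama3_tokenizer()
toks = assistant_prompt("What is the capital of France?", tkn)   # any Vector{Int} of token ids will do
inputs  = reshape(toks[1:end-1], :, 1)   # (seq_len, batch = 1)
targets = reshape(toks[2:end],   :, 1)   # next-token targets
loss = forward_loss(model, inputs, targets)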







#https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509
#=
#Use this one if you're using the original Meta weights.
41 changes: 34 additions & 7 deletions src/sampling.jl
@@ -1,12 +1,39 @@

function default_sampler(logits::AbstractVector)
return argmax(logits)
function argmax_sampler(logits::AbstractVector; device = identity)
return argmax(device(logits))
end

argmax_sampler(; device = identity) = logits -> argmax_sampler(logits; device = device)

function top_pk_sampler(logits::AbstractVector; p = 0.5f0, k = 5, device = identity)
probs = device(Jjama3.softmax(logits))
perm = partialsortperm(probs, 1:k, rev=true)
sorted_probs = probs[perm]
cumsum_probs = cumsum(sorted_probs)
if cumsum_probs[1] > p
return perm[1]
else
cutoff = findlast(cumsum_probs .< p)
return sample(perm[1:cutoff], Weights(sorted_probs[1:cutoff]))
end
end

top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device)
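#Quick sketch (not in the diff; random logits purely for illustration): both samplers can be
#called directly on a logits vector, or curried and passed to `generate` via the `sampler` keyword.
logits = randn(Float32, 128256)            # vocab-sized logits
argmax_sampler()(logits)                   # greedy token id
top_pk_sampler(p = 0.5f0, k = 5)(logits)   # top-k / top-p hybrid sample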

# This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence
# to the sampling of each token. But when I try and fix it, the model gets slightly dumber.
# Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time.
"""
Takes an initial sequence of tokens and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now.
Runs on the CPU by default. If the model is on the GPU (assuming Flux.jl, e.g. `model = gpu(model)`), pass `device = gpu` to `generate` to run on the GPU.
tkn = llama3_tokenizer()
generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010)
"""
function generate(model::Transformer{T},
initial_tokens::AbstractArray{IntT};
max_new_tokens=100,
sampler::Function=default_sampler,
sampler::Function=argmax_sampler,
encoder_for_printing = nothing,
end_token = 128010,
device = identity) where {T, IntT}
@@ -28,7 +55,7 @@ function generate(model::Transformer{T},
end
# Process the initial sequence
if current_len > 0
input_tokens = reshape(initial_tokens, :, 1) # (seq_len, batch=1)
input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1)
logits = forward_inference(model, input_tokens, 0)
start_pos = current_len
else
@@ -38,14 +65,14 @@ function generate(model::Transformer{T},
for _ in 1:max_new_tokens
# If sequence is empty or we want to process just the last token
if start_pos == 0
input_tokens = reshape([128001], :, 1) # Use start of text token if empty
input_tokens = device(reshape([128001], :, 1)) # Use start of text token if empty
else
input_tokens = reshape([tokens[current_len]], :, 1) # Just the last token
input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
end
# Get logits for next token
logits = forward_inference(model, input_tokens, start_pos)
# Sample next token (logits are size vocab × 1 × 1)
next_token = sampler(vec(logits[:, end, 1]))
next_token = sampler(logits[:, end, 1])
current_len += 1
tokens[current_len] = next_token
if !isnothing(encoder_for_printing)
22 changes: 18 additions & 4 deletions src/utils.jl
@@ -10,16 +10,16 @@ Format a prompt for use with Llama3.2's instruction format, with a simple "You a
assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a helpful assistant\n", prompt, tkn);


#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
"""
Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles.
tkn = llama3_tokenizer()
prompt = format_llama32_instruction_prompt("\\nYou are a helpful assistant\\n", "What is the capital of France?", tkn)
generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
"""
function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer)
#begin_of_text, start_header_id
prompt = [128001, 128007] #plus 1 because Julia is 1-indexed
function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer)
prompt = [128001, 128007] #begin_of_text, start_header_id
prompt = vcat(prompt, tokenizer.encode("system"))
push!(prompt, 128008) #end_header_id
prompt = vcat(prompt, tokenizer.encode(sys_prompt))
@@ -28,12 +28,25 @@ function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer)
push!(prompt, 128008) #end_header_id
prompt = vcat(prompt, tokenizer.encode("\n"))
prompt = vcat(prompt, tokenizer.encode(user_prompt))
prompt = vcat(prompt, [128009, 128007]) #eot_id, start_header_id
prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id
prompt = vcat(prompt, tokenizer.encode("assistant"))
push!(prompt, 128008) #end_header_id
return prompt
end

#These have already been incremented by 1 to account for Julia's 1-indexing
special_tokens = Dict(
"<|begin_of_text|>" => 128001,
"<|end_of_text|>" => 128002,
"<|start_header_id|>" => 128007,
"<|end_header_id|>" => 128008,
"<|eot_id|>" => 128010,
"<|finetune_right_pad_id|>" => 128005,
"<|python_tag|>" => 128011
)
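#For example, <|begin_of_text|> is 128000 in the Meta/Hugging Face (0-indexed) numbering, hence 128001 above.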

#[ "<|start_header_id|>user<|end_header_id|>\n\nGiven the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation?\nA. Planetary density will decrease.\nB. Planetary years will become longer.\nC. Planetary days will become shorter.\nD. Planetary gravity will become stronger.\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe best answer is" ]


"""
Load a Llama3 model from a set of Huggingface safetensors files, and the config.json file.
@@ -131,3 +144,4 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32
end

load_llama3_from_safetensors(path::String, config; T = Float32) = load_llama3_from_safetensors([path], config; T = T)
