Adding samplers
murrellb committed Nov 25, 2024
1 parent 45b431e commit 6288b70
Showing 6 changed files with 100 additions and 66 deletions.
4 changes: 4 additions & 0 deletions Project.toml
@@ -7,6 +7,7 @@ version = "1.0.0-DEV"
BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89"
@@ -15,6 +16,9 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[weakdeps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[sources]
HuggingFaceTokenizers = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"

[extensions]
MetalExt = "Metal"

9 changes: 9 additions & 0 deletions ext/MetalExt.jl
@@ -15,4 +15,13 @@ function NNlib.batched_mul(a::MtlArray, b::MtlArray)
return reshape(res, a_shape[1], b_shape[2], a_shape[3:end]...)
end

function NNlib.PermutedDimsArray(a::MtlArray, perm)
return permutedims(a, perm)
end

function NNlib.batched_transpose(a::MtlArray)
dims = size(a)
return permutedims(a, (2,1,3:length(dims)...))
end

end
26 changes: 18 additions & 8 deletions src/Jjama3.jl
@@ -1,23 +1,33 @@
module Jjama3

using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
import HuggingFaceTokenizers

tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained
tokenizer_from_file = HuggingFaceTokenizers.from_file
Tokenizer = HuggingFaceTokenizers.Tokenizer

include("model.jl")
include("utils.jl")
include("sampling.jl")
include("tokenizers.jl")

export load_huggingface_tokenizer_and_encoder,
load_llama321B_from_safetensors,
export load_llama321B_from_safetensors,
load_llama3_from_safetensors,
llama3_tokenizer,
assistant_prompt,
format_llama32_instruction_prompt,
generate,
forward_loss,
forward_inference,
top_pk_sampler,
argmax_sampler,
load_huggingface_tokenizer_and_encoder
top_n_sigma_sampler,
min_p_sampler,
tokenizer_from_repo,
tokenizer_from_file,
encode,
decode,
Tokenizer,
llama3_instruct_prompt,
llama3_assistant_prompt,
smollm2_instruct_prompt,
smollm2_assistant_prompt

end
2 changes: 1 addition & 1 deletion src/model.jl
@@ -292,7 +292,7 @@ function create_mask(h::AbstractArray)
T = eltype(h)
mask .= T(-Inf)
#mask = triu(mask, 1)
mask = tril(mask, -1)
mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup
return mask
end
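For reference, a small worked example (not part of the commit) of the orientation this produces for a sequence length of 3: the -Inf entries end up strictly below the diagonal rather than above it.

using LinearAlgebra # for tril

mask = zeros(Float32, 3, 3)
mask .= Float32(-Inf)
mask = tril(mask, -1)
# 3×3 Matrix{Float32}:
#    0.0    0.0   0.0
#  -Inf     0.0   0.0
#  -Inf   -Inf    0.0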

45 changes: 40 additions & 5 deletions src/sampling.jl
@@ -20,23 +20,58 @@ end

top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(logits; p, k, device)

# https://arxiv.org/pdf/2411.07641
function top_n_sigma_sampler(logits::AbstractVector{T}; temperature::T = 1.0f0, n::T = 1.0f0, device = identity) where T
scaled_logits = logits ./ temperature
M = maximum(scaled_logits)
σ = std(scaled_logits)
threshold = M - n * σ
mask = scaled_logits .>= threshold
masked_logits = copy(scaled_logits)
masked_logits[.!mask] .= -Inf
probs = device(Jjama3.softmax(masked_logits))
return sample(1:length(probs), Weights(probs))
end

top_n_sigma_sampler(; temperature = 1.0f0, n = 1.0f0, device = identity) = logits -> top_n_sigma_sampler(logits; temperature, n, device)
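
As an aside (not part of the diff), this sampler keeps only tokens whose temperature-scaled logit lies within n standard deviations of the maximum, then samples from the renormalized remainder. A minimal usage sketch with a made-up logits vector:

using Jjama3

logits = randn(Float32, 128_256)  # stand-in for one column of model output logits
token = top_n_sigma_sampler(logits; temperature = 0.7f0, n = 1.0f0)

# The keyword-only form returns a closure, suitable for passing as `sampler` to `generate`:
sampler = top_n_sigma_sampler(temperature = 0.7f0, n = 1.0f0)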

#https://arxiv.org/pdf/2407.01082
function min_p_sampler(logits::AbstractVector{T}; pbase::T = 0.5f0, device = identity) where T
probs = device(Jjama3.softmax(logits))
pmax = maximum(probs)
pscaled = pbase * pmax
mask = probs .>= pscaled
if !any(mask)
mask[argmax(probs)] = true
end
masked_probs = copy(probs)
masked_probs[.!mask] .= zero(T)
normalization = sum(masked_probs)
if normalization > 0
masked_probs ./= normalization
end
return sample(1:length(probs), Weights(masked_probs))
end

min_p_sampler(; pbase = 0.5f0, device = identity) = logits -> min_p_sampler(logits; pbase, device)
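
Similarly, min-p keeps only tokens whose probability is at least pbase times the largest token probability. A hedged usage sketch (the logits are again made up):

using Jjama3

logits = randn(Float32, 128_256)
token = min_p_sampler(logits; pbase = 0.1f0)   # direct call on a logits vector
sampler = min_p_sampler(pbase = 0.1f0)         # closure form for `generate`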

# This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence
# to sampling each new token. But when I try to fix it, the model gets slightly dumber.
# It feels like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time.
"""
generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010)
generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010)
Takes an initial sequence of tokens, and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dim for now.
Runs on the CPU by default. If the model is on the GPU (assuming Flux.jl, e.g. `model = gpu(model)`), pass `device = gpu` to `generate` to run on the GPU.
tkn = llama3_tokenizer()
generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010)
generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010)
"""
function generate(model::Transformer{T},
initial_tokens::AbstractArray{IntT};
max_new_tokens=100,
sampler::Function=argmax_sampler,
encoder_for_printing = nothing,
tokenizer_for_printing = nothing,
end_token = 128010,
device = identity) where {T, IntT}

@@ -77,8 +112,8 @@ function generate(model::Transformer{T},
next_token = sampler(logits[:, end, 1])
current_len += 1
tokens[current_len] = next_token
if !isnothing(encoder_for_printing)
print(encoder_for_printing.decode([next_token]))
if !isnothing(tokenizer_for_printing)
print(decode(tokenizer_for_printing, [next_token]))
end
if next_token == end_token
break
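To tie the new pieces together, a hypothetical sketch of `generate` with the renamed `tokenizer_for_printing` keyword and one of the new samplers. It assumes `model` (a loaded `Transformer`) and `tkn` (a loaded `Tokenizer`) already exist, since the exact loading calls depend on local files:

using Jjama3

# Assumes `model::Transformer` and `tkn::Tokenizer` were loaded earlier,
# e.g. via load_llama3_from_safetensors and tokenizer_from_file / tokenizer_from_repo.
prompt = llama3_assistant_prompt(tkn, "What is the capital of France?")
generate(model, prompt;
         max_new_tokens = 100,
         sampler = top_n_sigma_sampler(temperature = 0.7f0, n = 1.0f0),
         tokenizer_for_printing = tkn,
         end_token = 128010)
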
80 changes: 28 additions & 52 deletions src/utils.jl
@@ -1,64 +1,34 @@
"""
tkn = llama3_tokenizer()
encode(tkn::Tokenizer, str) = HuggingFaceTokenizers.encode(tkn, str).ids .+ 1
decode(tkn::Tokenizer, ids) = HuggingFaceTokenizers.decode(tkn, ids .- 1)

Load the tokenizer for Llama3. This seems to work, but I have not checked if there are some different edge-cases, or missing tokens relative to the original tokenizer (besides the special tokens we hackily include).

tkn = llama3_tokenizer()
tkn.encode("What is the capital of France?")
tkn.decode([10, 2, 5, 99])
"""
llama3_tokenizer() = BytePairEncoding.load_tiktoken_encoder("cl100k_base")
#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
function llama3_instruct_prompt(tokenizer,system_prompt, user_prompt)
str = """<|start_header_id|>system<|end_header_id|>
$system_prompt
<|eot_id|><|start_header_id|>user<|end_header_id|>
$(user_prompt)<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
return encode(tokenizer, str)
end

"""
generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
Format a prompt for use with Llama3.2's instruction format, with a simple "You are a helpful assistant" system prompt.
tkn = llama3_tokenizer()
prompt = assistant_prompt("What is the capital of France?", tkn)
prompt = assistant_prompt(tkn, "What is the capital of France?")
generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
"""
assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a helpful assistant\n", prompt, tkn);

llama3_assistant_prompt(tokenizer, prompt) = llama3_instruct_prompt(tokenizer,"\nYou are a helpful assistant\n", prompt);

#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
"""
generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles.
tkn = llama3_tokenizer()
prompt = format_llama32_instruction_prompt("\\nYou are a helpful assistant\\n", "What is the capital of France?", tkn)
generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
"""
function format_llama32_instruction_prompt(sys_prompt, user_prompt, tokenizer)
prompt = [128001, 128007] #begin_of_text, start_header_id
prompt = vcat(prompt, tokenizer.encode("system"))
push!(prompt, 128008) #end_header_id
prompt = vcat(prompt, tokenizer.encode(sys_prompt))
prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id
prompt = vcat(prompt, tokenizer.encode("user"))
push!(prompt, 128008) #end_header_id
prompt = vcat(prompt, tokenizer.encode("\n"))
prompt = vcat(prompt, tokenizer.encode(user_prompt))
prompt = vcat(prompt, [128010, 128007]) #eot_id, start_header_id
prompt = vcat(prompt, tokenizer.encode("assistant"))
push!(prompt, 128008) #end_header_id
return prompt
function smollm2_instruct_prompt(tokenizer, system_prompt, user_prompt)
str = """<|im_start|>system\n$(system_prompt)<|im_end|>\n<|im_start|>user\n$(user_prompt)<|im_end|>\n"""
return encode(tokenizer, str)
end

#These have already been incremented by 1 to account for Julia's 1-indexing
special_tokens = Dict(
"<|begin_of_text|>" => 128001,
"<|end_of_text|>" => 128002,
"<|start_header_id|>" => 128007,
"<|end_header_id|>" => 128008,
"<|eot_id|>" => 128010,
"<|finetune_right_pad_id|>" => 128005,
"<|python_tag|>" => 128011
)
smollm2_assistant_prompt(tokenizer, prompt) = smollm2_instruct_prompt(tokenizer, "You are a helpful AI assistant named SmolLM, trained by Hugging Face", prompt);
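
For illustration (not part of the diff): the new `encode`/`decode` wrappers shift token ids by +1/-1 to bridge Hugging Face's 0-based ids and Julia's 1-based indexing, and the SmolLM2 helpers simply encode a ChatML-style string. A sketch, assuming `tkn` is a loaded SmolLM2 `Tokenizer`:

using Jjama3

ids = encode(tkn, "What is the capital of France?")  # Hugging Face ids .+ 1
str = decode(tkn, ids)                               # shifts back .- 1 before decoding

prompt = smollm2_assistant_prompt(tkn, "What is the capital of France?")
# equivalent to encoding the string
# "<|im_start|>system\n...system prompt...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n"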

#[ "<|start_header_id|>user<|end_header_id|>\n\nGiven the following question and four candidate answers (A, B, C and D), choose the best answer.\nQuestion: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation?\nA. Planetary density will decrease.\nB. Planetary years will become longer.\nC. Planetary days will become shorter.\nD. Planetary gravity will become stronger.\nYour response should end with \"The best answer is [the_answer_letter]\" where the [the_answer_letter] is one of A, B, C or D.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe best answer is" ]


"""
@@ -75,12 +45,18 @@ so if you're loading weights from a different source, you might get very poor mo
"""
function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32)
config = Dict(config) #Just in case the user passed eg. a JSON3.Object
@assert config[:rope_scaling][:rope_type] == "llama3"
@assert config[:rope_scaling][:low_freq_factor] == 1
@assert config[:rope_scaling][:high_freq_factor] == 4
@assert config[:rope_scaling][:original_max_position_embeddings] == 8192
#@assert config[:rope_scaling][:rope_type] == "llama3"
#@assert config[:rope_scaling][:low_freq_factor] == 1
#@assert config[:rope_scaling][:high_freq_factor] == 4
#@assert config[:rope_scaling][:original_max_position_embeddings] == 8192

# Create model with config parameters from the JSON
scale_factor = 1f0
if haskey(config, :rope_scaling)
if !isnothing(config[:rope_scaling])
scale_factor = config[:rope_scaling][:factor]
end
end
model = Transformer(
config[:vocab_size], # vocab_size
config[:hidden_size], # dim (hidden_size)
@@ -92,7 +68,7 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32
norm_eps=T(config[:rms_norm_eps]), # rms_norm_eps
rope_theta=T(config[:rope_theta]), # rope_theta
use_scaled_rope=true, # Using scaled RoPE based on the config
scale_factor=config[:rope_scaling][:factor] # scale_factor
scale_factor=scale_factor # scale_factor
)

for path in paths # Process one file at a time
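For context on the new rope_scaling branch (an illustrative sketch, not from the commit): Llama-3-style configs ship a populated rope_scaling entry, roughly matching the now-commented-out asserts, while other configs set it to null, which is what the `isnothing` check handles. The `factor` value below is a placeholder, not a claim about any particular model.

# rope_scaling present: scale_factor is read from config[:rope_scaling][:factor]
llama_style = Dict(
    :rope_scaling => Dict(
        :rope_type => "llama3",
        :factor => 32.0,                      # placeholder value
        :low_freq_factor => 1.0,
        :high_freq_factor => 4.0,
        :original_max_position_embeddings => 8192,
    ),
)

# rope_scaling absent or null: falls back to scale_factor = 1f0
other_style = Dict(:rope_scaling => nothing)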
