Merge pull request #17 from MurrellGroup/cache-extend
Allowing extension of a previous cache.
murrellb authored Dec 9, 2024
2 parents 5a73614 + 0ad7210 commit da73a7b
Showing 9 changed files with 252 additions and 73 deletions.
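
A rough usage sketch of the new behaviour, inferred from the diffs below rather than taken from the PR itself: model is assumed to be a trained Transformer, and the toy character alphabet AAs is borrowed from examples/scratch.jl. A first generate call clears and sizes the per-layer KV caches; a follow-up call with clear_cache = false extends the existing cache and continues from the stored position, so only tokens the cache has not yet seen (here an arbitrary extra prompt) are passed in.

using Jjama3

AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")   # toy tokenizer from examples/scratch.jl

# First call: clears the cache, sizes it for the request, and generates.
first_part = generate(model, encode(AAs, ">"),
    max_new_tokens = 50,
    tokenizer_for_printing = AAs,
    end_token = 22)

# Second call: keep and extend the cached keys/values, feeding only the new tokens.
generate(model, encode(AAs, "QV"),
    max_new_tokens = 50,
    tokenizer_for_printing = AAs,
    end_token = 22,
    clear_cache = false)
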
9 changes: 5 additions & 4 deletions Project.toml
@@ -18,9 +18,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[weakdeps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[sources.HuggingFaceTokenizers]
rev = "main"
url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"
[sources]
HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}

[extensions]
MetalExt = "Metal"
@@ -39,6 +38,8 @@ julia = "1.11"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

[targets]
test = ["Test"]
test = ["Test", "Downloads", "JSON3"]
95 changes: 95 additions & 0 deletions examples/scratch.jl
@@ -0,0 +1,95 @@
#Pkg.add(["Flux", "JSON3", "UnicodePlots", "StatsBase"])
using Jjama3, Flux, StatsBase, UnicodePlots

#Init a tiny model
model = Transformer(
22, # vocab_size
16*8, # dim
12, # n_layers
8, # n_heads
4, # n_kv_heads
8192, # max_seq_len
16*10, # ff_hidden_dim
)

#Make everything except the RoPE trainable
Jjama3.Flux.@layer Jjama3.Transformer trainable=(tok_embeddings, layers, norm, output)
Jjama3.Flux.@layer Jjama3.Attention trainable=(wq, wk, wv, wo)
Jjama3.Flux.@layer Jjama3.TransformerBlock trainable=(attention, feed_forward, attention_norm, ffn_norm)

#Set up trivial tokenizer
AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")

#Read data, remove X-containing sequences, and add start and end tokens
data = readlines("abs.txt")
data = [">"*d*"." for d in data if !(occursin("X", d))]

#Train the model
lr = 0.001f0
opt_state = Flux.setup(AdamW(lr), model)
losses = Float32[]
for i in 1:2001
#Prep random batch
train_toks = pad_and_batch(encode.((AAs, ), data[sample(1:length(data), 10, replace=false)]), 22);
#Compute loss and gradients
loss, grads = Flux.withgradient(model) do m
forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:])
end
#Update weights
Flux.update!(opt_state, model, grads[1])
#Monitor
push!(losses, loss)
println(i, " ", loss)
#Monitor sampling
if mod(i, 100) == 1
generate(model, encode(AAs, ">"),
max_new_tokens=500,
tokenizer_for_printing=AAs,
end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
println()
display(lineplot(losses, width = 150, height = 30))
end
#Linear learning rate cooldown
if i > 1500
lr = max(lr - 0.001f0/(2000-1500), 0.0000001f0)
Flux.adjust!(opt_state, lr)
end
end

#Test sampling
for i in 1:10
println(">", i)
generate(model, encode(AAs, ">"),
max_new_tokens=500,
tokenizer_for_printing=AAs,
end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
println()
end

#Exporting the model
export_model(model, "tinyabllama.safetensors", type_convert = x -> Jjama3.SafeTensors.BFloat16.(x))

#Saving a config so that it loads correctly using the Jjama3 loader
using JSON3
config = Dict()
config[:model_type] = "llama"
config[:vocab_size] = 22
config[:hidden_size] = 16*8
config[:num_hidden_layers] = 12
config[:num_attention_heads] = 8
config[:num_key_value_heads] = 4
config[:max_position_embeddings] = 8192
config[:intermediate_size] = 16*10
config[:rms_norm_eps] = 1f-8
config[:rope_theta] = 500000f0
config[:tie_word_embeddings] = false
open("tinyabllama_config.json", "w") do f
JSON3.pretty(f, JSON3.write(config))
println(f)
end

#Load a trained model and test it
config = JSON3.read(read("tinyabllama_config.json", String))
model_weight_paths = ["tinyabllama.safetensors"]
model = load_llama3_from_safetensors(model_weight_paths, config)
@assert generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
3 changes: 3 additions & 0 deletions src/Jjama3.jl
@@ -24,6 +24,7 @@ export RoPE
export Attention
export TransformerBlock
export Transformer
export unrope

include("model.jl")
export forward_loss
@@ -49,5 +50,7 @@ export llama3_assistant_prompt
export smollm2_instruct_prompt
export smollm2_assistant_prompt
export structured_choice
export pad_and_batch
export export_model

end
11 changes: 10 additions & 1 deletion src/cache.jl
@@ -3,7 +3,9 @@ mutable struct KVCache{T,A<:AbstractArray{T,4}}
cache_v::A
end

Flux.@layer KVCache
Base.copy(cache::KVCache) = KVCache(copy(cache.cache_k), copy(cache.cache_v))

Flux.@layer KVCache trainable=()

head_dim(cache::KVCache) = size(cache.cache_k, 1)
seq_length(cache::KVCache) = size(cache.cache_k, 2)
@@ -21,6 +23,13 @@ function config!(cache::KVCache; seq_length=seq_length(cache), batch_size=batch_
cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
end

function extend!(cache::KVCache, new_total_length::Int)
old_cache = copy(cache)
config!(cache, seq_length=new_total_length)
cache.cache_k[:, 1:seq_length(old_cache), :, :] .= old_cache.cache_k
cache.cache_v[:, 1:seq_length(old_cache), :, :] .= old_cache.cache_v
end

clear!(cache::KVCache) = config!(cache, seq_length=0)

function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray)
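
The new extend! grows a cache to new_total_length positions while keeping whatever was already stored and zero-filling the tail. A minimal sketch of that behaviour, assuming the default two-field KVCache constructor and calling the unexported helpers directly:

using Jjama3: KVCache, extend!, seq_length   # unexported names, accessed directly

k = ones(Float32, 4, 2, 1, 1)        # (head_dim, seq_len, n_kv_heads, batch)
v = 2 .* ones(Float32, 4, 2, 1, 1)
cache = KVCache(k, v)

extend!(cache, 6)                                      # grow from 2 to 6 positions
@assert seq_length(cache) == 6
@assert cache.cache_k[:, 1:2, :, :] == k               # old entries preserved
@assert all(iszero, cache.cache_k[:, 3:end, :, :])     # new tail zero-filled
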
40 changes: 31 additions & 9 deletions src/layers.jl
@@ -40,16 +40,14 @@ struct RoPE{A<:AbstractArray}
sin::A
end

Flux.@layer RoPE
Flux.@layer RoPE trainable=()

Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:])

function apply_scaling!(freqs::AbstractVector; scale_factor=8)
#Hard-coded - should move these to the main model struct and grab them from the config.
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192
###
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
for (i, freq) in enumerate(freqs)
@@ -68,15 +66,15 @@ end

function RoPE(
dim::Int, end_pos::Int;
theta::T=10000f0, use_scaled=true, scale_factor=8,
theta::T=10000f0, use_scaled=true, scale_factor=8, start_pos=0
) where T
freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim))
use_scaled && apply_scaling!(freqs; scale_factor)
freqs_complex = cis.(T.(0:end_pos-1) * freqs')
freqs_complex = cis.(T.(start_pos:end_pos-1) * freqs')
cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len)
sin = permutedims(imag(freqs_complex), (2, 1))
cos = reshape(cos, (dim÷2, end_pos, 1, 1))
sin = reshape(sin, (dim÷2, end_pos, 1, 1))
cos = reshape(cos, (dim÷2, end_pos - start_pos, 1, 1))
sin = reshape(sin, (dim÷2, end_pos - start_pos, 1, 1))
return RoPE(cos, sin)
end

@@ -93,6 +91,15 @@ function (rope::RoPE)(x)
)
end

function unrope(rope, x)
head_dim = size(x, 1)
x1 = x[1:head_dim÷2, :, :, :]
x2 = x[head_dim÷2+1:end, :, :, :]
return vcat(
x1 .* rope.cos .+ x2 .* rope.sin,
x2 .* rope.cos .- x1 .* rope.sin
)
end

struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache}
wq::Q
@@ -195,12 +202,13 @@
Flux.@layer TransformerBlock trainable=(attention,)


struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE}
mutable struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE}
tok_embeddings::E
layers::B
norm::N
output::O
rope::R
pos::Int
end

Flux.@layer Transformer trainable=(layers,)
@@ -218,6 +226,20 @@ function Transformer(
layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers)
norm = RMSNorm(dim, eps=norm_eps)
output = Dense(dim => vocab_size, bias=false)
#This should probably be generated to a sane length, and then extended in the forward pass if needed.
rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor)
Transformer(tok_embeddings, layers, norm, output, rope)
Transformer(tok_embeddings, layers, norm, output, rope, 0)
end


function clear_cache!(model::Transformer)
model.pos = 0
for layer in model.layers
clear!(layer.attention.cache)
end
end

config_cache!(model::Transformer, seq_length) = for layer in model.layers config!(layer.attention.cache, seq_length = seq_length) end

extend_cache!(model::Transformer, seq_length) = for layer in model.layers extend!(layer.attention.cache, seq_length + model.pos) end

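A quick numerical check of the new unrope (a sketch, assuming rope(x) applies the rotation whose inverse is implemented above; the sizes are arbitrary toy values):

using Jjama3

head_dim, seq_len = 8, 16
rope = RoPE(head_dim, seq_len)                   # cos/sin tables for positions 0:seq_len-1
x = randn(Float32, head_dim, seq_len, 2, 1)      # (head_dim, seq_len, n_heads, batch)

@assert unrope(rope, rope(x)) ≈ x                # un-rotating recovers the original
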
67 changes: 34 additions & 33 deletions src/model.jl
@@ -1,58 +1,59 @@
#Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172

function create_mask(h::AbstractArray{T}) where T<:AbstractFloat
Flux.Zygote.ignore() do
function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractFloat
Flux.ChainRulesCore.ignore_derivatives() do
dim, seqlen, batch = size(h)
mask = similar(h, seqlen, seqlen)
mask .= T(-Inf)
mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup
if precached_size > 0
pad = similar(h, precached_size, seqlen)
pad .= T(0.0)
mask = vcat(pad, mask)
end
return mask
end
end

function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int=0)
function (model::Transformer)(tokens::AbstractArray{Int})
h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
rope = model.rope[start_pos+1:start_pos+size(tokens, 1)]
mask = create_mask(h)
rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
if size(h, 2) == 1
mask = create_mask(h)
else
mask = create_mask(h; precached_size = model.pos)
end
for layer in model.layers
h = layer(h, start_pos, rope, mask)
h = layer(h, model.pos, rope, mask)
end
h = model.norm(h)
output = model.output(h)
model.pos += size(tokens, 1)
return output
end

function masked_agg(ce, mask)
if mask !== nothing
ce = ce .* mask
end
return sum(ce)/sum(mask)
end

function forward_loss(model::Transformer, inputs::AbstractArray,
targets::AbstractArray; ignore_index::Int=-100,
mask = :auto)
seqlen = size(inputs, 1) #(seq_len, batch)
h = model.tok_embeddings(inputs) # (dim, seq_len, batch)
rope = model.rope[1:seqlen]
mask = create_mask(h)
for layer in model.layers
h = layer(h, 0, rope, mask)
targets::AbstractArray; clear_cache = true, loss_mask = nothing)
if clear_cache
Flux.ChainRulesCore.ignore_derivatives() do
clear_cache!(model)
end
end
h = model.norm(h)
logits = model.output(h)
# Need to reshape to (vocab_size, seq_len * batch)
logits_2d = reshape(logits, size(logits,1), :)
targets_1d = reshape(targets, :)
# Mask out ignored indices - will handle this later.
# Note: this is not the autoregressive mask, but the mask for the loss function
#=
mask = targets_1d .!= ignore_index
if any(mask)
loss = Flux.logitcrossentropy(
logits_2d[:, mask],
targets_1d[mask]
)
logits = model(inputs)
vocab_size = size(model.tok_embeddings.weight, 2)
gt = Flux.onehotbatch(targets, 1:vocab_size)
if loss_mask !== nothing
loss = Flux.logitcrossentropy(logits, gt, agg = x -> masked_agg(x, loss_mask))
else
loss = zero(Float32)
loss = Flux.logitcrossentropy(logits, gt)
end
=#
vocab_size = size(model.tok_embeddings.weight, 2)
gt = Flux.onehotbatch(targets_1d, 1:vocab_size)
loss = Flux.logitcrossentropy(logits_2d, gt)
return loss
end

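The reworked forward_loss routes the forward pass through the model's cache (cleared by default) and accepts an optional loss_mask aggregated by masked_agg. A hedged sketch of supplying a padding mask: the per-position cross-entropy keeps a leading singleton vocab dimension, so a (1, seq_len, batch) array of 0/1 weights broadcasts against it; the exact expected shape is an assumption, since the PR does not document it. model and the AAs alphabet are as in examples/scratch.jl (there the pad id 22 doubles as the end token, so that position is masked out too).

using Jjama3

AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
toks = pad_and_batch(encode.((AAs,), [">ACD.", ">ACDEFG."]), 22)   # (seq_len, batch), padded with 22

inputs  = toks[1:end-1, :]
targets = toks[2:end,  :]

# Weight 1 for real targets, 0 for padding, shaped to broadcast over the per-position loss.
loss_mask = reshape(Float32.(targets .!= 22), 1, size(targets)...)

loss = forward_loss(model, inputs, targets; loss_mask = loss_mask)
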
37 changes: 12 additions & 25 deletions src/sampling.jl
@@ -19,41 +19,28 @@ function generate(
sampler::Function=argmax_sampler,
tokenizer_for_printing = nothing,
end_token = 128010,
clear_cache = true,
pos_offset = 0,
device = identity
) where T
current_len = length(initial_tokens)
tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens))

for layer in model.layers
config!(layer.attention.cache, seq_length = current_len + max_new_tokens)
if clear_cache
clear_cache!(model)
config_cache!(model, current_len + max_new_tokens)
else
extend_cache!(model, current_len + max_new_tokens)
end

input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1)
logits = model(input_tokens, 0)
start_pos = current_len

# Generate new tokens one at a time
logits = model(input_tokens)
for _ in 1:max_new_tokens
# If sequence is empty or we want to process just the last token
if start_pos == 0
input_tokens = device(reshape([128001], :, 1)) # Use start of text token if empty
else
input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
end
# Get logits for next token
logits = model(input_tokens, start_pos)
# Sample next token (logits are size vocab × 1 × 1)
input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
logits = model(input_tokens)
next_token = sampler(logits[:, end, 1])
current_len += 1
tokens[current_len] = next_token
!isnothing(tokenizer_for_printing) && print(decode(tokenizer_for_printing, [next_token]))
!isnothing(tokenizer_for_printing) && print(decode(tokenizer_for_printing, [next_token], skip_special_tokens = false))
next_token == end_token && break
start_pos += 1
end
# Clear KV caches
for layer in model.layers
clear!(layer.attention.cache)
end
return tokens[1:current_len]
end

end