diff --git a/examples/scratch.jl b/examples/scratch.jl
new file mode 100644
index 0000000..3ca48fa
--- /dev/null
+++ b/examples/scratch.jl
@@ -0,0 +1,95 @@
+#Pkg.add(["Flux", "JSON3", "UnicodePlots", "StatsBase"])
+using Jjama3, Flux, StatsBase, UnicodePlots
+
+#Init a tiny model
+model = Transformer(
+    22,     # vocab_size
+    16*8,   # dim
+    12,     # n_layers
+    8,      # n_heads
+    4,      # n_kv_heads
+    8192,   # max_seq_len
+    16*10,  # ff_hidden_dim
+)
+
+#Make everything except the RoPE trainable
+Jjama3.Flux.@layer Jjama3.Transformer trainable=(tok_embeddings, layers, norm, output)
+Jjama3.Flux.@layer Jjama3.Attention trainable=(wq, wk, wv, wo)
+Jjama3.Flux.@layer Jjama3.TransformerBlock trainable=(attention, feed_forward, attention_norm, ffn_norm)
+
+#Set up trivial tokenizer
+AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
+
+#Read data, remove X-containing sequences, and add start and end tokens
+data = readlines("abs.txt")
+data = [">"*d*"." for d in data if !(occursin("X", d))]
+
+#Train the model
+lr = 0.001f0
+opt_state = Flux.setup(AdamW(lr), model)
+losses = Float32[]
+for i in 1:2001
+    #Prep random batch
+    train_toks = pad_and_batch(encode.((AAs, ), data[sample(1:length(data), 10, replace=false)]), 22);
+    #Compute loss and gradients
+    loss, grads = Flux.withgradient(model) do m
+        forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:])
+    end
+    #Update weights
+    Flux.update!(opt_state, model, grads[1])
+    #Monitor
+    push!(losses, loss)
+    println(i, " ", loss)
+    #Monitor sampling
+    if mod(i, 100) == 1
+        generate(model, encode(AAs, ">"),
+            max_new_tokens=500,
+            tokenizer_for_printing=AAs,
+            end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
+        println()
+        display(lineplot(losses, width = 150, height = 30))
+    end
+    #Linear learning rate cooldown
+    if i > 1500
+        lr = max(lr - 0.001f0/(2000-1500), 0.0000001f0)
+        Flux.adjust!(opt_state, lr)
+    end
+end
+
+#Test sampling
+for i in 1:10
+    println(">", i)
+    generate(model, encode(AAs, ">"),
+        max_new_tokens=500,
+        tokenizer_for_printing=AAs,
+        end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
+    println()
+end
+
+#Exporting the model
+export_model(model, "tinyabllama.safetensors", type_convert = x -> Jjama3.SafeTensors.BFloat16.(x))
+
+#Saving a config so that the model loads correctly using the Jjama3 loader
+using JSON3
+config = Dict()
+config[:model_type] = "llama"
+config[:vocab_size] = 22
+config[:hidden_size] = 16*8
+config[:num_hidden_layers] = 12
+config[:num_attention_heads] = 8
+config[:num_key_value_heads] = 4
+config[:max_position_embeddings] = 8192
+config[:intermediate_size] = 16*10
+config[:rms_norm_eps] = 1f-8
+config[:rope_theta] = 500000f0
+config[:tie_word_embeddings] = false
+open("tinyabllama_config.json", "w") do f
+    JSON3.pretty(f, JSON3.write(config))
+    println(f)
+end
+
+#Load a trained model and test it
+config = JSON3.read(read("tinyabllama_config.json", String))
+model_weight_paths = ["tinyabllama.safetensors"]
+model = load_llama3_from_safetensors(model_weight_paths, config)
+@assert generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
\ No newline at end of file
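The training loop above only reports the loss on the batch it just trained on. A held-out check is easy to add with the same helpers; the sketch below is not part of the patch and assumes `AAs`, `data`, and `model` as defined in the script, with the held-out sequences excluded from the sampled training batches (`holdout` and `val_toks` are illustrative names):

    #Held-out loss, using the same batch prep as the training loop
    holdout = data[1:100]
    val_toks = pad_and_batch(encode.((AAs, ), holdout), 22)
    val_loss = forward_loss(model, val_toks[1:end-1, :], val_toks[2:end, :])
    println("held-out loss: ", val_loss)
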
diff --git a/src/Jjama3.jl b/src/Jjama3.jl
index d985039..2251207 100644
--- a/src/Jjama3.jl
+++ b/src/Jjama3.jl
@@ -24,6 +24,7 @@ export RoPE
 export Attention
 export TransformerBlock
 export Transformer
+export unrope
 
 include("model.jl")
 export forward_loss
@@ -49,5 +50,7 @@ export llama3_assistant_prompt
 export smollm2_instruct_prompt
 export smollm2_assistant_prompt
 export structured_choice
+export pad_and_batch
+export export_model
 
 end
diff --git a/src/cache.jl b/src/cache.jl
index eac4e58..4a070f6 100644
--- a/src/cache.jl
+++ b/src/cache.jl
@@ -5,7 +5,7 @@ end
 
 Base.copy(cache::KVCache) = KVCache(copy(cache.cache_k), copy(cache.cache_v))
 
-Flux.@layer KVCache
+Flux.@layer KVCache trainable=()
 
 head_dim(cache::KVCache) = size(cache.cache_k, 1)
 seq_length(cache::KVCache) = size(cache.cache_k, 2)
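Marking `KVCache` with `trainable=()` keeps `cache_k`/`cache_v` out of the parameters Flux optimises, so `Flux.update!` never writes into the cache during the training loop in `examples/scratch.jl`. A quick way to confirm this, sketched here (not part of the patch) with a 2-layer version of the tiny model from that script:

    using Jjama3, Flux
    model = Transformer(22, 16*8, 2, 8, 4, 8192, 16*10)
    cache = model.layers[1].attention.cache
    Flux.trainable(cache)   #NamedTuple(): the cache exposes no trainable fields
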
diff --git a/src/layers.jl b/src/layers.jl
index 4eeee86..0dc3e4b 100644
--- a/src/layers.jl
+++ b/src/layers.jl
@@ -40,16 +40,14 @@ struct RoPE{A<:AbstractArray}
     sin::A
 end
 
-Flux.@layer RoPE
+Flux.@layer RoPE trainable=()
 
 Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:])
 
 function apply_scaling!(freqs::AbstractVector; scale_factor=8)
-    #Hard-coded - should move these to the main model struct and grab them from the config.
     low_freq_factor = 1
     high_freq_factor = 4
     old_context_len = 8192
-    ###
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
     for (i, freq) in enumerate(freqs)
@@ -68,15 +66,15 @@ end
 
 function RoPE(
     dim::Int, end_pos::Int;
-    theta::T=10000f0, use_scaled=true, scale_factor=8,
+    theta::T=10000f0, use_scaled=true, scale_factor=8, start_pos=0
 ) where T
     freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim))
     use_scaled && apply_scaling!(freqs; scale_factor)
-    freqs_complex = cis.(T.(0:end_pos-1) * freqs')
+    freqs_complex = cis.(T.(start_pos:end_pos-1) * freqs')
     cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len)
     sin = permutedims(imag(freqs_complex), (2, 1))
-    cos = reshape(cos, (dim÷2, end_pos, 1, 1))
-    sin = reshape(sin, (dim÷2, end_pos, 1, 1))
+    cos = reshape(cos, (dim÷2, end_pos - start_pos, 1, 1))
+    sin = reshape(sin, (dim÷2, end_pos - start_pos, 1, 1))
     return RoPE(cos, sin)
 end
 
@@ -93,6 +91,15 @@ function (rope::RoPE)(x)
     )
 end
 
+function unrope(rope, x)
+    head_dim = size(x, 1)
+    x1 = x[1:head_dim÷2, :, :, :]
+    x2 = x[head_dim÷2+1:end, :, :, :]
+    return vcat(
+        x1 .* rope.cos .+ x2 .* rope.sin,
+        x2 .* rope.cos .- x1 .* rope.sin
+    )
+end
 
 struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache}
     wq::Q
@@ -219,10 +226,12 @@ function Transformer(
     layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers)
     norm = RMSNorm(dim, eps=norm_eps)
     output = Dense(dim => vocab_size, bias=false)
+    #This should probably be generated to a sane length, and then extended in the forward pass if needed.
     rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor)
     Transformer(tok_embeddings, layers, norm, output, rope, 0)
 end
 
+
 function clear_cache!(model::Transformer)
     model.pos = 0
     for layer in model.layers
@@ -232,4 +241,5 @@ end
 
 config_cache!(model::Transformer, seq_length) = for layer in model.layers config!(layer.attention.cache, seq_length = seq_length) end
 
-extend_cache!(model::Transformer, seq_length) = for layer in model.layers extend!(layer.attention.cache, seq_length + model.pos) end
\ No newline at end of file
+extend_cache!(model::Transformer, seq_length) = for layer in model.layers extend!(layer.attention.cache, seq_length + model.pos) end
+
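The new `unrope` is intended as the inverse of the rotation the `RoPE` functor applies, which is useful for inspecting queries or keys in their unrotated form. A small round-trip check, sketched here (not part of the patch) and assuming the forward `(rope::RoPE)(x)` applies the standard half-split rotation to a head_dim × seq_len × n_heads × batch array:

    using Jjama3
    rope = RoPE(8, 16)                  #head_dim 8, positions 0:15
    x = randn(Float32, 8, 16, 1, 1)     #(head_dim, seq_len, n_heads, batch)
    unrope(rope, rope(x)) ≈ x           #should hold up to floating-point error
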
diff --git a/src/utils.jl b/src/utils.jl
index 65e7b5c..e26c2d0 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -3,6 +3,17 @@ using Accessors
 encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1
 decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...)
 
+#Trivial Char tokenizers:
+encode(chars::Vector{Char}, str::String) = [findfirst(==(c), chars) for c in str]
+decode(chars::Vector{Char}, enc::Vector{Int}; skip_special_tokens=false) = String([chars[i] for i in enc])
+
+#For training:
+function pad_and_batch(seqs, pad_token)
+    max_len = maximum(length.(seqs))
+    padded = [vcat(s, fill(pad_token, max_len - length(s))) for s in seqs]
+    cat(padded..., dims = 2)
+end
+
 #https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
 
 function llama3_instruct_prompt(tokenizer,system_prompt, user_prompt)
@@ -237,3 +248,41 @@ function structured_choice(choices::Vector{String}, vocab::Vector{String}, end_t
     end
     return choice_sampler
 end
+
+
+function export_model(model, output_path; type_convert = identity)
+    weights = Dict{String,AbstractArray}()
+    weights["model.embed_tokens.weight"] = type_convert(model.tok_embeddings.weight')
+    weights["lm_head.weight"] = type_convert(model.output.weight)
+    weights["model.norm.weight"] = type_convert(model.norm.weight)
+    for (i, layer) in enumerate(model.layers)
+        prefix = "model.layers.$(i-1)"
+        weights["$prefix.self_attn.q_proj.weight"] = type_convert(layer.attention.wq.weight)
+        weights["$prefix.self_attn.k_proj.weight"] = type_convert(layer.attention.wk.weight)
+        weights["$prefix.self_attn.v_proj.weight"] = type_convert(layer.attention.wv.weight)
+        weights["$prefix.self_attn.o_proj.weight"] = type_convert(layer.attention.wo.weight)
+        if layer.attention.wq.bias isa AbstractVector #Dense stores `false` when the bias is disabled
+            weights["$prefix.self_attn.q_proj.bias"] = type_convert(layer.attention.wq.bias)
+            weights["$prefix.self_attn.k_proj.bias"] = type_convert(layer.attention.wk.bias)
+            weights["$prefix.self_attn.v_proj.bias"] = type_convert(layer.attention.wv.bias)
+        end
+        weights["$prefix.mlp.gate_proj.weight"] = type_convert(layer.feed_forward.w1.weight)
+        weights["$prefix.mlp.down_proj.weight"] = type_convert(layer.feed_forward.w2.weight)
+        weights["$prefix.mlp.up_proj.weight"] = type_convert(layer.feed_forward.w3.weight)
+        weights["$prefix.input_layernorm.weight"] = type_convert(layer.attention_norm.weight)
+        weights["$prefix.post_attention_layernorm.weight"] = type_convert(layer.ffn_norm.weight)
+    end
+    SafeTensors.serialize(output_path, weights)
+    return nothing
+end
+
+#=
+#Example for how to save a model in safetensors format
+julia> using Random, BFloat16s, SafeTensors
+julia> weights = Dict("W"=>randn(BFloat16, 3, 5), "b"=>rand(BFloat16, 3))
+Dict{String, Array{BFloat16}} with 2 entries:
+  "W" => [0.617188 0.695312 … 0.390625 -2.0; -0.65625 -0.617188 … 0.652344 0.244141; 0.226562 2.70312 … -0.174805 -0.7773…
+  "b" => [0.111816, 0.566406, 0.283203]
+julia> f = tempname();
+julia> SafeTensors.serialize(f, weights)
+=#
\ No newline at end of file
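For reference, here is how the new char tokenizer and `pad_and_batch` compose, as they are used in `examples/scratch.jl` (a minimal sketch, not part of the patch; the two sequences are made up for illustration):

    using Jjama3
    AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
    seqs = encode.((AAs, ), [">ACDEF.", ">ACD."])   #'>' is token 1, '.' is token 22
    batch = pad_and_batch(seqs, 22)                 #7×2 Matrix{Int}, shorter sequence right-padded with 22
    Jjama3.decode(AAs, batch[:, 2])                 #">ACD..." (padding decodes to the '.' stop character)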