Train AA model from scratch
murrellb committed Dec 8, 2024
1 parent df14368 commit f27607c
Showing 5 changed files with 166 additions and 9 deletions.
95 changes: 95 additions & 0 deletions examples/scratch.jl
@@ -0,0 +1,95 @@
#Pkg.add(["Flux", "JSON3", "UnicodePlots", "StatsBase"])
using Jjama3, Flux, StatsBase, UnicodePlots

#Init a tiny model
model = Transformer(
    22, # vocab_size
    16*8, # dim
    12, # n_layers
    8, # n_heads
    4, # n_kv_heads
    8192, # max_seq_len
    16*10, # ff_hidden_dim
)

#Make everything except the RoPE trainable
Jjama3.Flux.@layer Jjama3.Transformer trainable=(tok_embeddings, layers, norm, output)
Jjama3.Flux.@layer Jjama3.Attention trainable=(wq, wk, wv, wo)
Jjama3.Flux.@layer Jjama3.TransformerBlock trainable=(attention, feed_forward, attention_norm, ffn_norm)

#Set up trivial tokenizer
AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
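#Quick sanity check of this trivial Char tokenizer (illustrative; indices follow the ordering of AAs above):
#encode(AAs, ">AC.") == [1, 2, 3, 22] and decode(AAs, [1, 2, 3, 22]) == ">AC."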

#Read data, remove X-containing sequences, and add start and end tokens
data = readlines("abs.txt")
data = [">"*d*"." for d in data if !(occursin("X", d))]

#Train the model
lr = 0.001f0
opt_state = Flux.setup(AdamW(lr), model)
losses = Float32[]
for i in 1:2001
    #Prep random batch
    train_toks = pad_and_batch(encode.((AAs, ), data[sample(1:length(data), 10, replace=false)]), 22);
    #Compute loss and gradients
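    #(inputs are tokens 1:end-1 and targets are tokens 2:end, i.e. standard next-token prediction)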
    loss, grads = Flux.withgradient(model) do m
        forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:])
    end
    #Update weights
    Flux.update!(opt_state, model, grads[1])
    #Monitor
    push!(losses, loss)
    println(i, " ", loss)
    #Monitor sampling
    if mod(i, 100) == 1
        generate(model, encode(AAs, ">"),
            max_new_tokens=500,
            tokenizer_for_printing=AAs,
            end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
        println()
        display(lineplot(losses, width = 150, height = 30))
    end
    #Linear learning rate cooldown
    if i > 1500
        global lr #so the cooldown updates the global lr when this runs as a script (soft scope)
        lr = max(lr - 0.001f0/(2000-1500), 0.0000001f0)
        Flux.adjust!(opt_state, lr)
    end
end

#Test sampling
for i in 1:10
    println(">", i)
    generate(model, encode(AAs, ">"),
        max_new_tokens=500,
        tokenizer_for_printing=AAs,
        end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
    println()
end

#Exporting the model
export_model(model, "tinyabllama.safetensors", type_convert = x -> Jjama3.SafeTensors.BFloat16.(x))
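#Note: the type_convert keyword controls the element type written to disk (BFloat16 here); with the
#default (identity) the weights should be written in their in-memory Float32 form.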

#Saving a config so that it loads correctly using the Jjama3 loader
using JSON3
config = Dict()
config[:model_type] = "llama"
config[:vocab_size] = 22
config[:hidden_size] = 16*8
config[:num_hidden_layers] = 12
config[:num_attention_heads] = 8
config[:num_key_value_heads] = 4
config[:max_position_embeddings] = 8192
config[:intermediate_size] = 16*10
config[:rms_norm_eps] = 1f-8
config[:rope_theta] = 500000f0
config[:tie_word_embeddings] = false
open("tinyabllama_config.json", "w") do f
    JSON3.pretty(f, JSON3.write(config))
    println(f)
end

#Load a trained model and test it
config = JSON3.read(read("tinyabllama_config.json", String))
model_weight_paths = ["tinyabllama.safetensors"]
model = load_llama3_from_safetensors(model_weight_paths, config)
@assert generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
3 changes: 3 additions & 0 deletions src/Jjama3.jl
@@ -24,6 +24,7 @@ export RoPE
export Attention
export TransformerBlock
export Transformer
export unrope

include("model.jl")
export forward_loss
@@ -49,5 +50,7 @@ export llama3_assistant_prompt
export smollm2_instruct_prompt
export smollm2_assistant_prompt
export structured_choice
export pad_and_batch
export export_model

end
2 changes: 1 addition & 1 deletion src/cache.jl
@@ -5,7 +5,7 @@ end

Base.copy(cache::KVCache) = KVCache(copy(cache.cache_k), copy(cache.cache_v))

Flux.@layer KVCache
Flux.@layer KVCache trainable=()

head_dim(cache::KVCache) = size(cache.cache_k, 1)
seq_length(cache::KVCache) = size(cache.cache_k, 2)
26 changes: 18 additions & 8 deletions src/layers.jl
@@ -40,16 +40,14 @@ struct RoPE{A<:AbstractArray}
    sin::A
end

Flux.@layer RoPE
Flux.@layer RoPE trainable=()

Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:])

function apply_scaling!(freqs::AbstractVector; scale_factor=8)
    #Hard-coded - should move these to the main model struct and grab them from the config.
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192
    ###
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    for (i, freq) in enumerate(freqs)
@@ -68,15 +66,15 @@ end

function RoPE(
    dim::Int, end_pos::Int;
    theta::T=10000f0, use_scaled=true, scale_factor=8,
    theta::T=10000f0, use_scaled=true, scale_factor=8, start_pos=0
) where T
    freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim))
    use_scaled && apply_scaling!(freqs; scale_factor)
    freqs_complex = cis.(T.(0:end_pos-1) * freqs')
    freqs_complex = cis.(T.(start_pos:end_pos-1) * freqs')

    cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len)
    sin = permutedims(imag(freqs_complex), (2, 1))
    cos = reshape(cos, (dim÷2, end_pos, 1, 1))
    sin = reshape(sin, (dim÷2, end_pos, 1, 1))
    cos = reshape(cos, (dim÷2, end_pos - start_pos, 1, 1))
    sin = reshape(sin, (dim÷2, end_pos - start_pos, 1, 1))

    return RoPE(cos, sin)
end
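#Note: with the new start_pos keyword, RoPE(head_dim, end_pos; start_pos=p) should hold cos/sin tables
#for positions p:end_pos-1, e.g. RoPE(64, 1024; start_pos=512) covers positions 512:1023.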

@@ -93,6 +91,15 @@ function (rope::RoPE)(x)
    )
end

function unrope(rope, x)
    head_dim = size(x, 1)
    x1 = x[1:head_dim÷2, :, :, :]
    x2 = x[head_dim÷2+1:end, :, :, :]
    return vcat(
        x1 .* rope.cos .+ x2 .* rope.sin,
        x2 .* rope.cos .- x1 .* rope.sin
    )
end
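#Note: unrope is intended to invert the rotation applied by (rope::RoPE)(x), so for a tensor roped at the
#same positions, unrope(rope, rope(x)) should recover x up to floating-point error.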

struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache}
    wq::Q
@@ -219,10 +226,12 @@ function Transformer(
    layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers)
    norm = RMSNorm(dim, eps=norm_eps)
    output = Dense(dim => vocab_size, bias=false)
    #This should probably be generated at a sane length, and then extended in the forward pass if needed.
    rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor)
    Transformer(tok_embeddings, layers, norm, output, rope, 0)
end


function clear_cache!(model::Transformer)
    model.pos = 0
    for layer in model.layers
@@ -232,4 +241,5 @@ end

config_cache!(model::Transformer, seq_length) = for layer in model.layers config!(layer.attention.cache, seq_length = seq_length) end


extend_cache!(model::Transformer, seq_length) = for layer in model.layers extend!(layer.attention.cache, seq_length + model.pos) end

49 changes: 49 additions & 0 deletions src/utils.jl
@@ -3,6 +3,17 @@ using Accessors
encode(tkn::Tokenizer, str; kwargs...) = HuggingFaceTokenizers.encode(tkn, str; kwargs...).ids .+ 1
decode(tkn::Tokenizer, ids; kwargs...) = HuggingFaceTokenizers.decode(tkn, ids .- 1; kwargs...)

#Trivial Char tokenizers:
encode(chars::Vector{Char}, str::String) = [findfirst(==(c), chars) for c in str]
decode(chars::Vector{Char}, enc::Vector{Int}; skip_special_tokens=false) = String([chars[i] for i in enc])


#For training:
function pad_and_batch(seqs, pad_token)
    max_len = maximum(length.(seqs))
    padded = [vcat(s, fill(pad_token, max_len - length(s))) for s in seqs]
    cat(padded..., dims = 2)
end
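#For example (illustrative): pad_and_batch([[1, 2, 3], [1, 2]], 22) gives the 3×2 matrix [1 1; 2 2; 3 22],
#with one sequence per column, right-padded with the pad token.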


#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
function llama3_instruct_prompt(tokenizer,system_prompt, user_prompt)
@@ -237,3 +248,41 @@ function structured_choice(choices::Vector{String}, vocab::Vector{String}, end_t
    end
    return choice_sampler
end


function export_model(model, output_path; type_convert = identity)
    weights = Dict{String,AbstractArray}()
    weights["model.embed_tokens.weight"] = type_convert(model.tok_embeddings.weight')
    weights["lm_head.weight"] = type_convert(model.output.weight)
    weights["model.norm.weight"] = type_convert(model.norm.weight)
    for (i, layer) in enumerate(model.layers)
        prefix = "model.layers.$(i-1)"
        weights["$prefix.self_attn.q_proj.weight"] = type_convert(layer.attention.wq.weight)
        weights["$prefix.self_attn.k_proj.weight"] = type_convert(layer.attention.wk.weight)
        weights["$prefix.self_attn.v_proj.weight"] = type_convert(layer.attention.wv.weight)
        weights["$prefix.self_attn.o_proj.weight"] = type_convert(layer.attention.wo.weight)
        if layer.attention.wq.bias isa AbstractVector #Flux Dense stores `bias = false` when bias is disabled
            weights["$prefix.self_attn.q_proj.bias"] = type_convert(layer.attention.wq.bias)
            weights["$prefix.self_attn.k_proj.bias"] = type_convert(layer.attention.wk.bias)
            weights["$prefix.self_attn.v_proj.bias"] = type_convert(layer.attention.wv.bias)
        end
        weights["$prefix.mlp.gate_proj.weight"] = type_convert(layer.feed_forward.w1.weight)
        weights["$prefix.mlp.down_proj.weight"] = type_convert(layer.feed_forward.w2.weight)
        weights["$prefix.mlp.up_proj.weight"] = type_convert(layer.feed_forward.w3.weight)
        weights["$prefix.input_layernorm.weight"] = type_convert(layer.attention_norm.weight)
        weights["$prefix.post_attention_layernorm.weight"] = type_convert(layer.ffn_norm.weight)
    end
    SafeTensors.serialize(output_path, weights)
    return nothing
end

#=
#Example of how to save a model in safetensors format
julia> using Random, BFloat16s, SafeTensors
julia> weights = Dict("W"=>randn(BFloat16, 3, 5), "b"=>rand(BFloat16, 3))
Dict{String, Array{BFloat16}} with 2 entries:
"W" => [0.617188 0.695312 … 0.390625 -2.0; -0.65625 -0.617188 … 0.652344 0.244141; 0.226562 2.70312 … -0.174805 -0.7773…
"b" => [0.111816, 0.566406, 0.283203]
julia> f = tempname();
julia> SafeTensors.serialize(f, weights)
=#
