Llama3.1 loading fix.
murrellb committed Nov 14, 2024
1 parent cfb71e9 commit eb8e61d
Showing 2 changed files with 26 additions and 9 deletions.
14 changes: 12 additions & 2 deletions README.md
```diff
@@ -1,4 +1,4 @@
-# Jjama3
+# Jjama3 - Hackable Llama3.1 and Llama3.2 in Julia
 
 [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/stable/)
 [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/dev/)
```

@@ -7,10 +7,20 @@

# Quickstart

Download a Llama3 model's `config.json` and safetensors weights from Huggingface, e.g. [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct). You might need access permissions for this. Note: Huggingface uses a different RoPE convention from the original Meta implementation, and their weights have been permuted accordingly. This package works with the Huggingface convention, so if you load the original Meta weights you'll need to permute them (see the sketch after the quickstart example below).

```julia
using Jjama3, JSON3

config = JSON3.read(read("Llama3_2_1B_instruct/config.json", String));
model = load_llama3_from_safetensors("Llama3_2_1B_instruct/model.safetensors", config);
tkn = llama3_tokenizer();
prompt = assistant_prompt("Why would anyone implement the llama3 LLM in Julia?", tkn);
ts = generate(model, prompt, max_new_tokens=500, encoder_for_printing=tkn);
```

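For the permutation mentioned in the note above, here is a rough illustrative sketch, not a Jjama3 function. It assumes a `q_proj`/`k_proj` weight stored as an `(n_heads * head_dim) × dim` matrix with output features along the rows, and that the Meta-to-Huggingface conversion amounts to deinterleaving each head's (real, imag) rotary pairs into two contiguous halves; the helper name is made up for illustration.

```julia
# Hypothetical helper, for illustration only (not part of Jjama3): reorder the
# rows of a Meta-format q_proj/k_proj weight into the Huggingface layout,
# one attention head at a time.
function permute_meta_to_hf(W::AbstractMatrix, n_heads::Integer)
    head_dim = size(W, 1) ÷ n_heads
    # Meta interleaves (real, imag) rotary pairs; Huggingface groups them into halves.
    perm = vcat(1:2:head_dim, 2:2:head_dim)
    Wp = similar(W)
    for h in 1:n_heads
        rows = (h - 1) * head_dim .+ (1:head_dim)
        Wp[rows, :] = W[rows[perm], :]
    end
    return Wp
end
```

Checkpoints downloaded from Huggingface (like the one in the quickstart) are already in this layout, so no permutation is needed there.
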
# Capability

- Seems to generate reasonable text from Llama3.1 and Llama3.2 models, loaded from Huggingface safetensors.
- Sampling accelerated with KV caching, with argmax and top-p sampling supported.
- Gradients seem to work on CPU, using Flux and Zygote. Untested on GPU.
- Sampling (and forward passes) work with CUDA, where everything is much faster. Gradients untested.
- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is much slower with Metal than with CPU.
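
To make the sampling bullet above concrete, here is a generic sketch of top-p (nucleus) sampling over a vector of logits. It is illustrative only and not Jjama3's sampler; the function name and keywords are made up.

```julia
using Random

# Generic top-p (nucleus) sampling: keep the smallest set of tokens whose
# probability mass reaches p, then sample a token index from that set.
function sample_top_p(logits::AbstractVector; p = 0.9, rng = Random.default_rng())
    probs = exp.(logits .- maximum(logits))
    probs ./= sum(probs)                           # softmax
    order = sortperm(probs; rev = true)            # most probable tokens first
    csum = cumsum(probs[order])
    k = something(findfirst(>=(p), csum), length(csum))
    keep = order[1:k]                              # the nucleus
    renorm = probs[keep] ./ sum(probs[keep])
    idx = min(searchsortedfirst(cumsum(renorm), rand(rng)), k)
    return keep[idx]
end
```
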
21 changes: 14 additions & 7 deletions src/utils.jl
```diff
@@ -1,3 +1,12 @@
+"""
+    tkn = llama3_tokenizer()
+
+Load the tokenizer for Llama3. This seems to work, but I have not checked for edge-case differences or missing tokens relative to the original tokenizer (besides the special tokens we hackily include).
+
+    tkn = llama3_tokenizer()
+    tkn.encode("What is the capital of France?")
+    tkn.decode([10, 2, 5, 99])
+"""
 llama3_tokenizer() = BytePairEncoding.load_tiktoken_encoder("cl100k_base")
 
 """
```

```diff
@@ -88,19 +97,17 @@ function load_llama3_from_safetensors(paths::Vector{String}, config; T = Float32
 
     for path in paths # Process one file at a time
         weights = load_safetensors(path)
-
+        if (haskey(weights, "lm_head.weight") && (config[:tie_word_embeddings]))
+            error("tie_word_embeddings was true, but lm_head.weight was present.")
+        end
         if haskey(weights, "model.embed_tokens.weight")
             model.tok_embeddings.weight .= weights["model.embed_tokens.weight"]'
             if config[:tie_word_embeddings]
                 model.output.weight .= weights["model.embed_tokens.weight"]
             end
         end
-        if !config[:tie_word_embeddings]
-            if haskey(weights, "lm_head.weight")
-                model.output.weight .= weights["lm_head.weight"]
-            else
-                error("tie_word_embeddings was true, but lm_head.weight was present.")
-            end
+        if haskey(weights, "lm_head.weight")
+            model.output.weight .= weights["lm_head.weight"]
         end
         if haskey(weights, "model.norm.weight")
             model.norm.weight .= weights["model.norm.weight"]
```
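
The change above matters mostly for sharded checkpoints: Llama-3.1-8B-style models ship several safetensors files, only one of which contains `lm_head.weight`, and the old branch errored on shards that lacked it. A hedged usage sketch follows, since `load_llama3_from_safetensors` accepts a vector of paths; the directory and shard filenames below are assumptions for illustration, not outputs of this repo.

```julia
using Jjama3, JSON3

config = JSON3.read(read("Llama3_1_8B_instruct/config.json", String))
# Hypothetical shard names following the usual Huggingface pattern.
paths = ["Llama3_1_8B_instruct/model-0000$(i)-of-00004.safetensors" for i in 1:4]
model = load_llama3_from_safetensors(paths, config)
```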
