Merge pull request #9 from MurrellGroup/lora
Readme tweaks
murrellb authored Nov 26, 2024
2 parents 509b162 + b563a47 commit 15a3617
Showing 1 changed file (README.md) with 52 additions and 16 deletions.
## Installation


We've split this into a few (unregistered) packages, so you'll need to add them all. You'll also need JSON3 for loading the configs:
```julia
] add JSON3
] add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl
] add https://github.com/MurrellGroup/LowRankLayers.jl
] add https://github.com/MurrellGroup/LogitSamplers.jl
```
Download a Llama3 model `config.json`, `tokenizer.json`, and model safetensor weights from Hugging Face, e.g. [SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct/tree/main). Note: Hugging Face Llama3 models use a different RoPE convention from the original Meta implementation, and their weights have been permuted. This package works with the Hugging Face convention, so if you load the original Meta-Llama weights from a different source you'll need to do something horrible (sketched after the example below).

```julia
using JSON3, Jjama3

config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String))
model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config)
tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json")

prompt = smollm2_assistant_prompt(tkn, "Tell me the two worst things about Python.")
generate(model, prompt,
    max_new_tokens=500,
    tokenizer_for_printing=tkn,
    end_token = encode(tkn, "<|im_end|>")[end]);
```
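
If you do find yourself with original-convention Meta-Llama weights, the "something horrible" is a row permutation of the Q and K projection matrices, mirroring the one in Hugging Face's conversion script. Here is a minimal sketch, assuming output features are stored as rows (as in PyTorch); `meta_to_hf_qk` is a hypothetical helper, not part of this package:

```julia
#Hypothetical helper, not part of Jjama3: reorder the rows of a Meta-convention
#Q or K projection so that each head's interleaved RoPE pairs become the
#half-and-half layout Hugging Face (and this package) expects.
#For K under grouped-query attention, pass the number of KV heads instead.
function meta_to_hf_qk(W::AbstractMatrix, n_heads::Integer)
    head_dim = size(W, 1) ÷ n_heads
    rows = Int[]
    for h in 0:n_heads-1
        base = h * head_dim
        append!(rows, base .+ (1:2:head_dim)) #first element of each rotary pair
        append!(rows, base .+ (2:2:head_dim)) #second element of each rotary pair
    end
    return W[rows, :]
end
```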

## Capability
The transformer emits "logits" which control the probability of the next token.

```julia
prompt = smollm2_assistant_prompt(tkn, "Tell me the two worst things about Python.")

generate(model, prompt,
    max_new_tokens=500,
    tokenizer_for_printing=tkn,
    end_token = encode(tkn, "<|im_end|>")[end],
    sampler = top_nσ_sampler());
```
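
For intuition, top-nσ sampling keeps only the tokens whose logits lie within n standard deviations of the maximum logit, then samples from the renormalized distribution over those. A rough sketch of the idea (not the `LogitSamplers` implementation):

```julia
using Statistics, StatsBase

#Sketch of the top-nσ idea, not the LogitSamplers implementation
function top_nsigma_sketch(logits::AbstractVector; n = 1.0f0, temperature = 1.0f0)
    l = logits ./ temperature
    keep = l .>= maximum(l) - n * std(l) #retain logits within n std devs of the max
    p = exp.(l .- maximum(l)) .* keep    #softmax numerator, masked to the kept set
    return sample(1:length(l), Weights(p ./ sum(p)))
end
```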

## Structured Sampling
You can pass in a custom sampler that places additional constraints on the sampling. For example, `structured_choice` restricts generation to a fixed set of answers:

```julia
question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?"
choices = ["Prior", "Likelihood", "Marginal Likelihood", "Evidence", "Posterior", "Margin Call"]
choices = ["Prior",
"Likelihood",
"Marginal Likelihood",
"Evidence",
"Posterior"]

#The token vocabulary as strings, one entry per row of the model's output projection
vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:size(model.output.weight, 1)]
eos = encode(tkn, "<|im_end|>")[end]
sysprompt = "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible."
prompt = smollm2_instruct_prompt(tkn, sysprompt, question)

generate(model, prompt,
    max_new_tokens=100,
    tokenizer_for_printing=tkn,
    end_token = eos,
    sampler = structured_choice(choices, vocab, eos));
```

This strategy can be extended to force the model's outputs to follow specific formats.
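
As an illustration of this idea (assuming, as the `structured_choice` usage above suggests, that a sampler is simply a function from a logit vector to a token index), a hypothetical sampler that only ever emits tokens from an allowed set might look like this; `allowed_set_sampler` and `digit_ids` are made up for this sketch and are not part of the package:

```julia
#Hypothetical, for illustration: greedily pick the best token from an allowed
#set by masking all other logits out. Assumes 1-based token ids, matching the
#vocab construction above.
function allowed_set_sampler(allowed::Vector{Int})
    return function (logits)
        masked = fill(-Inf32, length(logits))
        masked[allowed] .= logits[allowed]
        return argmax(masked)
    end
end

#e.g. constrain answers to digits (plus the end token):
digit_ids = [encode(tkn, string(d))[end] for d in 0:9]
sampler = allowed_set_sampler(vcat(digit_ids, eos))
```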
## Finetuning

Often we want to adjust model parameters to better fit our specific use case, by finetuning on new data. Here we set up LoRA finetuning, which trains cheap low-rank corrections to selected weight matrices:

```julia
using Jjama3, JSON3, Flux

config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String))
tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json")
eos = encode(tkn, "<|im_end|>")[end]

#Add LoRA to Q and V matrices when loading the model
model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config,
    add_lora_to = [:Q, :V], lora_dim = 64)

#See how the model answers before finetuning
prompt = smollm2_assistant_prompt(tkn, "What language is the best for deep learning?");
generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = eos);

#Set up a single, very silly, training example to finetune on
ugh = "Ugh, bruh, what a stupid question.<|im_end|>"
trainsample = decode(tkn, prompt, skip_special_tokens = false) * ugh;
train_toks = encode(tkn, trainsample);

#Set up the optimizer
opt_state = Flux.setup(AdamW(0.001f0), model);

#Train for 5 steps, monitoring the model's output as it tunes
for i in 1:5
    grads = Flux.gradient(model) do m
        #Next-token prediction: inputs are tokens 1:n-1, targets are the same tokens shifted to 2:n
        forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:])
    end
    Flux.update!(opt_state, model, grads[1])
    println(i)
    generate(model, prompt,
        max_new_tokens=50,
        tokenizer_for_printing=tkn,
        end_token = eos)
    println()
end

#Ask the model an unrelated question to see how stupid we've made the model. Try this a few times.
prompt = smollm2_assistant_prompt(tkn, "Explain how tides work?");
generate(model, prompt,
    max_new_tokens=500,
    tokenizer_for_printing=tkn,
    end_token = eos,
    sampler = top_nσ_sampler());
```
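
The loop above minimizes the standard next-token objective: the logits the model emits at position t are scored against the token observed at position t+1. As a sketch of that objective (an assumption about what `forward_loss` computes, not its actual code):

```julia
using Flux

#Sketch of next-token cross-entropy; assumes logits are vocab × sequence and
#targets are the same token sequence shifted one position ahead.
function next_token_loss_sketch(logits::AbstractMatrix, targets::AbstractVector)
    return Flux.logitcrossentropy(logits, Flux.onehotbatch(targets, 1:size(logits, 1)))
end
```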
