
Commit

Docstrings.
murrellb committed Nov 14, 2024
1 parent 5e790b1 commit 7e78f19
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/sampling.jl
@@ -24,6 +24,8 @@ top_pk_sampler(;p = 0.5f0, k = 5, device = identity) = logits -> top_pk_sampler(
# to the sampling of each token. But when I try and fix it, the model gets slightly dumber.
# Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time.
"""
    generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), encoder_for_printing=tkn, end_token=128010)
Takes an initial sequence of tokens and generates new tokens one at a time until the end token is sampled. Uses a KV cache. No batch dimension for now.
Runs on the CPU by default. If the model is on the GPU (assuming Flux.jl, e.g. `model = gpu(model)`), pass `device = gpu` to `generate` to run on the GPU.
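As a quick illustration of the documented call (a sketch, not part of this commit: it assumes the package's exports and an already-constructed `model`, and the exact tokenizer `encode` call is an assumption):

```julia
using Jjama3  # module name assumed from this repository

tkn = llama3_tokenizer()
initial_tokens = encode(tkn, "The capital of France is")  # hypothetical encode call

# CPU generation with the documented defaults:
generate(model, initial_tokens;
    max_new_tokens = 50,
    sampler = top_pk_sampler(p = 0.5f0, k = 5),
    encoder_for_printing = tkn,
    end_token = 128010)

# If the model is on the GPU via Flux.jl:
# generate(gpu(model), initial_tokens; device = gpu, encoder_for_printing = tkn)
```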
6 changes: 6 additions & 0 deletions src/utils.jl
@@ -1,6 +1,8 @@
llama3_tokenizer() = BytePairEncoding.load_tiktoken_encoder("cl100k_base")

"""
    generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
Format a prompt for use with Llama3.2's instruction format, with a simple "You are a helpful assistant" system prompt.
tkn = llama3_tokenizer()
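A usage sketch for the prompt helper documented above, grounded in the `assistant_prompt(prompt, tkn)` definition visible in the next hunk; that the returned prompt feeds straight into `generate` is an assumption:

```julia
tkn = llama3_tokenizer()
prompt = assistant_prompt("What is the capital of France?", tkn)

# Feed the formatted prompt to `generate` (see the sampling.jl docstring above):
generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
```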
@@ -12,6 +14,8 @@ assistant_prompt(prompt, tkn) = format_llama32_instruction_prompt("\nYou are a h

#https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2/
"""
    generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
Format a prompt for use with Llama3.2's instruction format, injecting the system and user roles.
tkn = llama3_tokenizer()
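For a custom system role, the lower-level formatter can be called directly. Only its first (system prompt) argument is visible in this diff, so the full `(sys_prompt, prompt, tkn)` argument order here is an assumption:

```julia
tkn = llama3_tokenizer()
sys_prompt = "\nYou are a terse assistant. Answer in one sentence."
prompt = format_llama32_instruction_prompt(sys_prompt, "What is RoPE?", tkn)  # argument order assumed
generate(model, prompt, max_new_tokens=100, encoder_for_printing=tkn)
```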
@@ -49,6 +53,8 @@ special_tokens = Dict(


"""
    model = load_llama3_from_safetensors(model_weight_paths, config)
Load a Llama3 model from a set of Huggingface safetensors files and the config.json file.
Important note: Huggingface uses a different RoPE convention than other implementations,
so if you're loading weights from a different source, you might get very poor model performance.
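A loading sketch based on the signature above; the JSON3 config parsing and the local file paths are assumptions, not from this commit:

```julia
using JSON3  # config-parsing choice is an assumption

config = JSON3.read(read("Llama3.2-1B-Instruct/config.json", String))

# Hypothetical paths; sharded checkpoints would pass several files here.
model_weight_paths = ["Llama3.2-1B-Instruct/model.safetensors"]

model = load_llama3_from_safetensors(model_weight_paths, config)
```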
