
Reactant compatibility #15

Merged
merged 8 commits on Dec 4, 2024
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,9 +1,10 @@
name = "Jjama3"
uuid = "1285d783-1a6d-4703-8f05-8ac83ef55592"
authors = ["murrellb <[email protected]> and contributors"]
version = "1.0.0-DEV"
version = "1.1.0-DEV"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
@@ -19,13 +20,13 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[sources]
HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}
LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"}
LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"}

[extensions]
MetalExt = "Metal"

[compat]
Accessors = "0.1.38"
Distributions = "0.25"
Flux = "0.14"
LogitSamplers = "0.1"
2 changes: 1 addition & 1 deletion README.md
@@ -45,7 +45,7 @@ generate(model, prompt,
- RoPE scaling (for exceeding the model's max training-time context length) is implemented, but likely incorrect with the KV cache. Be careful if you're using it with very long sequences.
- Imported models are trainable (with Flux), including with low-rank (i.e. LoRA) finetuning.
- Sampling, training, etc. are compatible with CUDA, where everything is much faster.
- Metal acceleration for forward_inference, forward_loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU.
- Metal acceleration for forward inference, forward loss, and sampling. Gradients (with Zygote) fail. Sampling works, but is slower with Metal than with CPU.


## Samplers
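For orientation, here is a minimal sketch of how the exported samplers plug into `generate`. The `max_new_tokens` and `sampler` keyword names, and the sampler's `p`/`k` parameters, are assumptions based on the exports in this PR, not confirmed by the diff:

```julia
using Jjama3

# Assumes `model` and `tkn` (a Tokenizer) are already loaded.
prompt = llama3_instruct_prompt(tkn, "You are a helpful assistant.", "Say hello!")
generate(model, prompt,
    max_new_tokens = 50,                       # assumed keyword
    sampler = top_pk_sampler(p = 0.5, k = 5))  # assumed keyword and parameters
```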
6 changes: 4 additions & 2 deletions ext/MetalExt.jl
@@ -1,7 +1,9 @@
module MetalExt

#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling.
#It seems that each single Metal call has some constant overhead that kills it.
# See https://github.com/FluxML/NNlib.jl/pull/614

# Note: Metal speeds things up a little for forward inference and forward loss calls, but is VERY slow for sampling.
# It seems that each single Metal call has some constant overhead that kills it.

using Metal, NNlib

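For reference, this file is gated on the `[extensions]` entry in Project.toml (`MetalExt = "Metal"`), so it only activates when Metal is loaded alongside Jjama3. This is standard package-extension behavior, not something this PR introduces:

```julia
using Jjama3   # MetalExt stays dormant; Metal is only a weak dependency
using Metal    # loading Metal triggers Jjama3's MetalExt
```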
69 changes: 39 additions & 30 deletions src/Jjama3.jl
@@ -1,44 +1,53 @@
module Jjama3

using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
using LogitSamplers, LowRankLayers
import HuggingFaceTokenizers
using Flux
using SafeTensors
using Distributions
using LinearAlgebra
using StatsBase
using NNlib
using LogitSamplers
using LowRankLayers

using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer

const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained
const tokenizer_from_file = HuggingFaceTokenizers.from_file
const Tokenizer = HuggingFaceTokenizers.Tokenizer

const top_pk_sampler = LogitSamplers.top_pk_sampler
const argmax_sampler = LogitSamplers.argmax_sampler
const min_p_sampler = LogitSamplers.min_p_sampler
const top_nσ_sampler = LogitSamplers.top_nσ_sampler


include("cache.jl")
export KVCache

include("layers.jl")
export FeedForward
export RMSNorm
export RoPE
export Attention
export TransformerBlock
export Transformer

include("model.jl")
include("utils.jl")
export forward_loss
export forward_inference

include("sampling.jl")
export top_pk_sampler
export argmax_sampler
export top_nσ_sampler
export min_p_sampler
export generate
export tokenizer_from_repo
export tokenizer_from_file
export Tokenizer

export load_llama321B_from_safetensors,
load_llama3_from_safetensors,
generate,
forward_loss,
forward_inference,
top_pk_sampler,
argmax_sampler,
top_nσ_sampler,
min_p_sampler,
tokenizer_from_repo,
tokenizer_from_file,
encode,
decode,
Tokenizer,
llama3_instruct_prompt,
llama3_assistant_prompt,
smollm2_instruct_prompt,
smollm2_assistant_prompt,
structured_choice
include("utils.jl")
export encode
export decode
export load_llama321B_from_safetensors
export load_llama3_from_safetensors
export llama3_instruct_prompt
export llama3_assistant_prompt
export smollm2_instruct_prompt
export smollm2_assistant_prompt
export structured_choice

end
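A hedged sketch of the tokenizer-facing exports above. The repo id is a placeholder, and `encode` returning a vector of token ids (with `decode` as its inverse) is an assumption about utils.jl:

```julia
using Jjama3

tkn = tokenizer_from_repo("meta-llama/Llama-3.2-1B")  # placeholder repo id
ids = encode(tkn, "Hello, world!")  # assumed to return token ids
decode(tkn, ids)                    # assumed inverse
```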
37 changes: 37 additions & 0 deletions src/cache.jl
@@ -0,0 +1,37 @@
mutable struct KVCache{T,A<:AbstractArray{T,4}}
cache_k::A
cache_v::A
end

Flux.@layer KVCache

head_dim(cache::KVCache) = size(cache.cache_k, 1)
seq_length(cache::KVCache) = size(cache.cache_k, 2)
n_kv_heads(cache::KVCache) = size(cache.cache_k, 3)
batch_size(cache::KVCache) = size(cache.cache_k, 4)


function KVCache(T; head_dim, seq_length=0, n_kv_heads, batch_size=1)
cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
return KVCache(cache_k, cache_v)
end

function config!(cache::KVCache; seq_length=seq_length(cache), batch_size=batch_size(cache))
cache.cache_k = similar(cache.cache_k, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
end

clear!(cache::KVCache) = config!(cache, seq_length=0)


function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray)
if iszero(seq_length(cache))
# Cache not configured (seq_length == 0): pass keys/values through unchanged.
return xk, xv
else
seqlen = size(xk, 2)
cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk
cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv
return cache.cache_k[:, 1:start_pos+seqlen, :, :],
       cache.cache_v[:, 1:start_pos+seqlen, :, :]
end
end
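A usage sketch for the cache added above, following the `(head_dim, seq_length, n_kv_heads, batch_size)` array layout its accessors imply:

```julia
cache = KVCache(Float32; head_dim=64, seq_length=32, n_kv_heads=8, batch_size=1)

xk = randn(Float32, 64, 16, 8, 1)  # 16 new key positions
xv = randn(Float32, 64, 16, 8, 1)

# Write the new entries at positions 1:16, get back all cached keys/values so far.
k, v = update!(cache, 0, xk, xv)
@assert size(k) == (64, 16, 8, 1)

config!(cache; seq_length=128)  # reallocate to a longer sequence (zeroed)
clear!(cache)                   # seq_length = 0; update! then passes xk, xv through
```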