Commit 76bb79b

Add test

AntonOresten committed Dec 8, 2024
1 parent f27607c commit 76bb79b

Showing 3 changed files with 16 additions and 5 deletions.
Project.toml (4 changes: 2 additions & 2 deletions)

@@ -20,7 +20,6 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 
 [sources]
 HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}
-LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"}
 
 [extensions]
 MetalExt = "Metal"
@@ -39,6 +38,7 @@ julia = "1.11"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 
 [targets]
-test = ["Test"]
+test = ["Test", "JSON3"]
src/model.jl (4 changes: 2 additions & 2 deletions)

@@ -1,7 +1,7 @@
 #Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172
 
 function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractFloat
-    Flux.Zygote.ignore() do
+    Flux.ChainRulesCore.ignore_derivatives() do
         dim, seqlen, batch = size(h)
         mask = similar(h, seqlen, seqlen)
         mask .= T(-Inf)
@@ -42,7 +42,7 @@ end
 function forward_loss(model::Transformer, inputs::AbstractArray,
                       targets::AbstractArray; clear_cache = true, loss_mask = nothing)
     if clear_cache
-        Flux.Zygote.ignore() do
+        Flux.ChainRulesCore.ignore_derivatives() do
             clear_cache!(model)
         end
     end
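Both hunks swap Zygote.ignore, which is Zygote-specific, for the AD-agnostic ChainRulesCore equivalent. Code inside the block still runs during the forward pass, but no gradient is propagated through it, which is what you want for side computations like building an attention mask or clearing a KV cache. A minimal sketch of the pattern, not Jjama3 code (the function and mask here are illustrative):

```julia
using Flux  # ChainRulesCore is a Flux dependency, reachable as Flux.ChainRulesCore, as the diff above relies on

function masked_sum(x::AbstractVector{T}) where T<:AbstractFloat
    # The mask is built inside ignore_derivatives, so AD treats it as a
    # constant instead of differentiating through its construction.
    mask = Flux.ChainRulesCore.ignore_derivatives() do
        m = similar(x)
        m .= T.(isodd.(eachindex(x)))  # keep odd positions; stand-in for a real attention mask
        m
    end
    return sum(x .* mask)
end

# The gradient is just the mask; mask construction contributes nothing.
Flux.gradient(masked_sum, [1.0, 2.0, 3.0, 4.0])  # ([1.0, 0.0, 1.0, 0.0],)
```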
test/runtests.jl (13 changes: 12 additions & 1 deletion)

@@ -1,6 +1,17 @@
 using Jjama3
 using Test
+using JSON3
 
 @testset "Jjama3.jl" begin
-    # Write your tests here.
+
+    @testset "Amino Acid Model" begin
+        url_branch = "https://raw.githubusercontent.com/MurrellGroup/Jjama3.jl/aminoacid-model/"
+        config_path = download(url_branch * "tinyabllama_config.json")
+        model_path = download(url_branch * "tinyabllama.safetensors")
+        config = JSON3.read(read(config_path, String))
+        model = load_llama3_from_safetensors([model_path], config)
+        AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
+        @test generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
+    end
+
 end
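The new testset doubles as a usage example: fetch a config and a safetensors checkpoint, parse the config with JSON3, load the model, and sample from a start token. A standalone sketch of the same flow follows; the names load_llama3_from_safetensors, encode, and generate are taken from the test above, and the final decoding line assumes generate returns a vector of vocabulary indices (which the expected value in the @test suggests):

```julia
using Jjama3, JSON3, Downloads

url_branch = "https://raw.githubusercontent.com/MurrellGroup/Jjama3.jl/aminoacid-model/"
config_path = Downloads.download(url_branch * "tinyabllama_config.json")
model_path = Downloads.download(url_branch * "tinyabllama.safetensors")

config = JSON3.read(read(config_path, String))  # model hyperparameters
model = load_llama3_from_safetensors([model_path], config)

AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")  # 22-symbol vocabulary; "." (index 22) ends a sequence
prompt = encode(AAs, ">")                # encode the start symbol as token indices
seq = generate(model, prompt, end_token = 22)  # sample until the stop token
println(join(AAs[seq]))                  # map indices back to amino-acid letters
```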
