From 76bb79b7ba7c82975075ab5a839e417f0f9cf535 Mon Sep 17 00:00:00 2001
From: Anton Oresten
Date: Sun, 8 Dec 2024 22:16:11 +0100
Subject: [PATCH] Add test

---
 Project.toml     |  4 ++--
 src/model.jl     |  4 ++--
 test/runtests.jl | 13 ++++++++++++-
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/Project.toml b/Project.toml
index 9ef394d..b4e9eab 100644
--- a/Project.toml
+++ b/Project.toml
@@ -20,7 +20,6 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 
 [sources]
 HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}
-LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"}
 
 [extensions]
 MetalExt = "Metal"
@@ -39,6 +38,7 @@ julia = "1.11"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 
 [targets]
-test = ["Test"]
+test = ["Test", "JSON3"]
diff --git a/src/model.jl b/src/model.jl
index 05c2cf5..ea95d52 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,7 +1,7 @@
 #Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172
 
 function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractFloat
-    Flux.Zygote.ignore() do
+    Flux.ChainRulesCore.ignore_derivatives() do
         dim, seqlen, batch = size(h)
         mask = similar(h, seqlen, seqlen)
         mask .= T(-Inf)
@@ -42,7 +42,7 @@ end
 
 function forward_loss(model::Transformer, inputs::AbstractArray, targets::AbstractArray; clear_cache = true, loss_mask = nothing)
     if clear_cache
-        Flux.Zygote.ignore() do
+        Flux.ChainRulesCore.ignore_derivatives() do
             clear_cache!(model)
         end
     end
diff --git a/test/runtests.jl b/test/runtests.jl
index 4ce44a1..476864f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,17 @@
 using Jjama3
 using Test
+using JSON3
 
 @testset "Jjama3.jl" begin
-    # Write your tests here.
+
+    @testset "Amino Acid Model" begin
+        url_branch = "https://raw.githubusercontent.com/MurrellGroup/Jjama3.jl/aminoacid-model/"
+        config_path = download(url_branch * "tinyabllama_config.json")
+        model_path = download(url_branch * "tinyabllama.safetensors")
+        config = JSON3.read(read(config_path, String))
+        model = load_llama3_from_safetensors([model_path], config)
+        AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
+        @test generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
+    end
+
 end