Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Dec 9, 2024
1 parent 5c26f2a commit 0ad7210
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[weakdeps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[sources.HuggingFaceTokenizers]
rev = "main"
url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"
[sources]
HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}

[extensions]
MetalExt = "Metal"
Expand All @@ -39,7 +38,8 @@ julia = "1.11"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

[targets]
test = ["Test", "JSON3"]
test = ["Test", "Downloads", "JSON3"]
4 changes: 2 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172

function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractFloat
Flux.ChainRulesCore.ignore() do
Flux.ChainRulesCore.ignore_derivatives() do
dim, seqlen, batch = size(h)
mask = similar(h, seqlen, seqlen)
mask .= T(-Inf)
Expand Down Expand Up @@ -42,7 +42,7 @@ end
function forward_loss(model::Transformer, inputs::AbstractArray,
targets::AbstractArray; clear_cache = true, loss_mask = nothing)
if clear_cache
Flux.ChainRulesCore.ignore() do
Flux.ChainRulesCore.ignore_derivatives() do
clear_cache!(model)

Check warning on line 46 in src/model.jl

View check run for this annotation

Codecov / codecov/patch

src/model.jl#L44-L46

Added lines #L44 - L46 were not covered by tests
end
end
Expand Down
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
using Jjama3
using Test
using JSON3
using Downloads

@testset "Jjama3.jl" begin

@testset "Amino Acid Model" begin
url_branch = "https://raw.githubusercontent.com/MurrellGroup/Jjama3.jl/aminoacid-model/"
config_path = download(url_branch * "tinyabllama_config.json")
model_path = download(url_branch * "tinyabllama.safetensors")
config_path = Downloads.download(url_branch * "tinyabllama_config.json")
model_path = Downloads.download(url_branch * "tinyabllama.safetensors")
config = JSON3.read(read(config_path, String))
model = load_llama3_from_safetensors([model_path], config)
AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")
Expand Down

0 comments on commit 0ad7210

Please sign in to comment.