Skip to content

Commit

Permalink
test/runtests.jl: Fix tests, as structure has no scale parameter now.
Browse files Browse the repository at this point in the history
  • Loading branch information
mashu committed Nov 28, 2024
1 parent 56cf469 commit 8f2120a
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ end
head_dim = size(x, 1)
cos_mat = view(pe.cos_cached, 1:head_dim, 1:1, 1:seq_len, :)
sin_mat = view(pe.sin_cached, 1:head_dim, 1:1, 1:seq_len, :)
expected_output = @. muladd(x * pe.scale, cos_mat, neg_half_x * pe.scale * sin_mat)
expected_output = @. muladd(x, cos_mat, neg_half_x * sin_mat)

# Test the forward pass
actual_output = pe(x)
Expand Down Expand Up @@ -188,8 +188,7 @@ end
rope_gpu = RoPE(
rope_gpu.head_dim,
cu(rope_gpu.cos_cached),
cu(rope_gpu.sin_cached),
rope_gpu.scale
cu(rope_gpu.sin_cached)
)
x = CUDA.randn(Float32, 8, 1, 4, 1)

Expand All @@ -204,8 +203,7 @@ end
pe_gpu = RoPE(
pe.head_dim,
cu(pe.cos_cached),
cu(pe.sin_cached),
pe.scale
cu(pe.sin_cached)
)

cpu_output = pe(x)
Expand Down

0 comments on commit 8f2120a

Please sign in to comment.