Skip to content

Commit

Permalink
Even more CuArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt committed Jan 30, 2025
1 parent ff453f2 commit 9af84b3
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions test/libraries/cublas/level3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,13 +439,14 @@ k = 13
end
@testset "syr2k" begin
alpha = rand(elty)
dalpha = CuArray{elty}([alpha])
A = rand(elty,m,k)
B = rand(elty,m,k)
# move to device
d_A = CuArray(A)
d_B = CuArray(B)
C = alpha*(A*transpose(B) + B*transpose(A))
d_C = CUBLAS.syr2k('U','N',alpha,d_A,d_B)
d_C = CUBLAS.syr2k('U','N',dalpha,d_A,d_B)
# move back to host and compare
C = triu(C)
h_C = Array(d_C)
Expand All @@ -455,12 +456,14 @@ k = 13
if elty <: Complex
@testset "herk!" begin
alpha = rand(elty)
dalpha = CuArray{real(elty)}([real(alpha)])
beta = rand(elty)
dbeta = CuArray{real(elty)}([real(beta)])
A = rand(elty,m,m)
hA = A + A'
d_A = CuArray(A)
d_C = CuArray(hA)
CUBLAS.herk!('U','N',real(alpha),d_A,real(beta),d_C)
CUBLAS.herk!('U','N',dalpha,d_A,dbeta,d_C)
C = real(alpha)*(A*A') + real(beta)*hA
C = triu(C)
# move to host and compare
Expand All @@ -484,7 +487,9 @@ k = 13
elty2 = real(elty)
# generate parameters
α = rand(elty1)
= CuArray{elty2}(α)
β = rand(elty2)
= CuArray{elty2}(β)
A = rand(elty,m,k)
B = rand(elty,m,k)
Bbad = rand(elty,m+1,k+1)
Expand All @@ -496,7 +501,7 @@ k = 13
C = C + C'
d_C = CuArray(C)
C = α*(A*B') + conj(α)*(B*A') + β*C
CUBLAS.her2k!('U','N',α,d_A,d_B,β,d_C)
CUBLAS.her2k!('U','N',,d_A,d_B,,d_C)
# move back to host and compare
C = triu(C)
h_C = Array(d_C)
Expand Down

0 comments on commit 9af84b3

Please sign in to comment.