Skip to content

Commit

Permalink
Accomodate for rectangular matrices in copytrito! (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Oct 30, 2024
1 parent b97643c commit 43241ea
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,12 @@ if isdefined(LinearAlgebra, :copytrito!)
LinearAlgebra.BLAS.chkuplo(uplo)
m,n = size(A)
m1,n1 = size(B)
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
if uplo == 'U'
if n < m
(m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
end
@kernel function U_kernel!(_A, _B)
I = @index(Global, Cartesian)
i, j = Tuple(I)
Expand All @@ -122,6 +126,11 @@ if isdefined(LinearAlgebra, :copytrito!)
end
U_kernel!(get_backend(B))(A, B; ndrange = size(A))
else # uplo == 'L'
if m < n
(m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
end
@kernel function L_kernel!(_A, _B)
I = @index(Global, Cartesian)
i, j = Tuple(I)
Expand Down
7 changes: 7 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@
B = zeros(T,n,n)
@test compare(copytrito!, AT, B, A, uplo)
end
@testset for T in eltypes, uplo in ('L', 'U')
n = 16
m = 32
A = uplo == 'U' ? rand(T,m,n) : rand(T,n,m)
B = zeros(T,n,n)
@test compare(copytrito!, AT, B, A, uplo)
end
end
end

Expand Down

0 comments on commit 43241ea

Please sign in to comment.