Skip to content

Commit

Permalink
Accomodate for rectangular matrices in copytrito!
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed May 30, 2024
1 parent 1b3741c commit c1d2db5
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,12 @@ if isdefined(LinearAlgebra, :copytrito!)
LinearAlgebra.BLAS.chkuplo(uplo)
m,n = size(A)
m1,n1 = size(B)
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
if uplo == 'U'
if n < m
(m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
end
gpu_call(A, B) do ctx, _A, _B
I = @cartesianidx _A
i, j = Tuple(I)
Expand All @@ -144,6 +148,11 @@ if isdefined(LinearAlgebra, :copytrito!)
return
end
else # uplo == 'L'
if m < n
(m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)"))
else
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)"))
end
gpu_call(A, B) do ctx, _A, _B
I = @cartesianidx _A
i, j = Tuple(I)
Expand Down

0 comments on commit c1d2db5

Please sign in to comment.