diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 7548b4f1..79384a7d 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -111,8 +111,12 @@ if isdefined(LinearAlgebra, :copytrito!) LinearAlgebra.BLAS.chkuplo(uplo) m,n = size(A) m1,n1 = size(B) - (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) if uplo == 'U' + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end @kernel function U_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I) @@ -122,6 +126,11 @@ if isdefined(LinearAlgebra, :copytrito!) end U_kernel!(get_backend(B))(A, B; ndrange = size(A)) else # uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end @kernel function L_kernel!(_A, _B) I = @index(Global, Cartesian) i, j = Tuple(I) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 7c03e69f..8de31854 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -85,6 +85,13 @@ B = zeros(T,n,n) @test compare(copytrito!, AT, B, A, uplo) end + @testset for T in eltypes, uplo in ('L', 'U') + n = 16 + m = 32 + A = uplo == 'U' ? rand(T,m,n) : rand(T,n,m) + B = zeros(T,n,n) + @test compare(copytrito!, AT, B, A, uplo) + end end end