From c1d2db560be7339728d1c2371330ad754adb7a14 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Wed, 29 May 2024 20:53:03 -0400 Subject: [PATCH] Accomodate for rectangular matrices in copytrito! --- src/host/linalg.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index dd6d2b92..f7176e9f 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -133,8 +133,12 @@ if isdefined(LinearAlgebra, :copytrito!) LinearAlgebra.BLAS.chkuplo(uplo) m,n = size(A) m1,n1 = size(B) - (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) if uplo == 'U' + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end gpu_call(A, B) do ctx, _A, _B I = @cartesianidx _A i, j = Tuple(I) @@ -144,6 +148,11 @@ if isdefined(LinearAlgebra, :copytrito!) return end else # uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least size ($m,$n)")) + end gpu_call(A, B) do ctx, _A, _B I = @cartesianidx _A i, j = Tuple(I)