diff --git a/src/host/linalg.jl b/src/host/linalg.jl index dd6d2b92..5619dd83 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -359,11 +359,11 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac @inbounds if i <= size(A,1) && j <= size(B,2) z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) - Ctmp = convert(promote_type(R, typeof(z2)), z2) + Cij = convert(promote_type(R, typeof(z2)), z2) for k in 1:size(A,2) - Ctmp += A[i, k]*B[k, j] + Cij += A[i, k]*B[k, j] end - C[i,j] = add(Ctmp, C[i,j]) + C[i,j] = add(Cij, C[i,j]) end return @@ -388,7 +388,184 @@ end function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b)) end -end +end + +function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} + if size(A,2) != size(B,1) + throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + + upper = tfun === identity ? uploc == 'U' : uploc != 'U' + unit = isunitc == 'U' + + function trimatmul(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += (unit ? one(Cij) : A[i,i]) * B[i,j] + for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) + Cij += A[i,k] * B[k,j] + end + C[i,j] += Cij + end + + return + end + + function trimatmul_t(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += (unit ? one(Cij) : transpose(A[i,i])) * B[i,j] + for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) + Cij += transpose(A[k,i]) * B[k,j] + end + C[i,j] += Cij + end + + return + end + + function trimatmul_a(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += (unit ? one(Cij) : adjoint(A[i,i])) * B[i,j] + for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) + Cij += adjoint(A[k,i]) * B[k,j] + end + C[i,j] += Cij + end + + return + end + + if tfun === identity + gpu_call(trimatmul, C, A, B; name="trimatmul") + elseif tfun == transpose + gpu_call(trimatmul_t, C, A, B; name="trimatmul_t") + elseif tfun === adjoint + gpu_call(trimatmul_a, C, A, B; name="trimatmul_a") + else + error("Not supported") + end + + C +end + +function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} + if size(A,2) != size(B,1) + throw(DimensionMismatch(lazy"matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch(lazy"result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + + upper = tfun === identity ? uploc == 'U' : uploc != 'U' + unit = isunitc == 'U' + + function mattrimul(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += A[i,j] * (unit ? one(Cij) : B[j,j]) + for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m) + Cij += A[i,k] * B[k,j] + end + C[i,j] += Cij + end + + return + end + + function mattrimul_t(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += A[i,j] * (unit ? one(Cij) : transpose(B[j,j])) + for k in (upper ? 1 : (j + 1) ):(upper ? (j - 1) : m) + Cij += A[i,k] * transpose(B[j,k]) + end + C[i,j] += Cij + end + + return + end + + function mattrimul_a(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += A[i,j] * (unit ? one(Cij) : adjoint(B[j,j])) + for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m) + Cij += A[i,k] * adjoint(B[j,k]) + end + C[i,j] += Cij + end + + return + end + + if tfun === identity + gpu_call(mattrimul, C, A, B; name="mattrimul") + elseif tfun == transpose + gpu_call(mattrimul_t, C, A, B; name="mattrimul_t") + elseif tfun === adjoint + gpu_call(mattrimul_a, C, A, B; name="mattrimul_a") + else + error("Not supported") + end + + C +end + +if VERSION >= v"1.10-" +function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUVecOrMat) + generic_trimatmul!(C, uploc, isunitc, tfun, A, B) +end +function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUMatrix) + generic_mattrimul!(C, uploc, isunitc, tfun, A, B) +end +end if VERSION < v"1.10.0-DEV.1365" # catch other functions that are called by LinearAlgebra's mul! diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index da858608..d84bb5bd 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -132,6 +132,39 @@ @test istriu(A) == istriu(B) end end + + if VERSION >= v"1.10-" + @testset "mul! + Triangular" begin + @testset "trimatmul! ($TR x $T, $f)" for T in (Float32, ComplexF32), TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), f in (identity, transpose, adjoint) + n = 128 + A = AT(rand(T, n,n)) + b = AT(rand(T, n)) + Ct = AT(zeros(T, n)) + C = zeros(T, n) + mul!(Ct, f(TR(A)), b) + mul!(C, f(TR(collect(A))), collect(b)) + @test collect(Ct) ≈ C + + B = AT(rand(T, n, n)) + Ct = AT(zeros(T, n, n)) + C = zeros(T, n, n) + mul!(Ct, f(TR(A)), B) + mul!(C, f(TR(collect(A))), collect(B)) + @test collect(Ct) ≈ C + end + + @testset "mattrimul ($TR x $T, $f)" for T in (Float32, ComplexF32), TR in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), f in (identity, transpose, adjoint) + n = 128 + A = AT(rand(T, n,n)) + B = AT(rand(T, n, n)) + Ct = AT(zeros(T, n, n)) + C = zeros(T, n, n) + mul!(Ct, A, f(TR(B))) + mul!(C, collect(A), f(TR(collect(B)))) + @test collect(Ct) ≈ C + end + end + end end @testset "diagonal" begin