Skip to content

Commit

Permalink
Add generic_trimatmul and generic_mattrimul
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Apr 20, 2024
1 parent 228fb9e commit 23b3c45
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,43 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
C
end

function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
if size(A,2) != size(B,1)
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
end
if size(C,1) != size(A,1) || size(C,2) != size(B,2)
throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))"))
end
if isempty(A) || isempty(B)
return fill!(C, zero(R))
end

upper = uploc == 'U'
unit = isunitc == 'U'

# tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'

gpu_call(C, A, B; name="trimatmul") do ctx, C, A, B
idx = @linearidx C
assume.(size(C) .> 0)
i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1

@inbounds if i <= size(A,1) && j <= size(B,2)
z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j])
Ctmp = convert(promote_type(R, typeof(z2)), z2)
Ctmp += (unit ? one(Ctmp) : A[i, i]) * B[i, j]
for k in ((upper ? i : 1) + 1):(upper ? size(A,2) : size(A,2) - i)
Ctmp += A[i, k]*B[k, j]
end
C[i,j] += Ctmp
end

return
end

C
end

function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul())
generic_matmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta)
end
Expand All @@ -380,6 +417,18 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
end

if VERSION >= v"1.10-"

function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUVecOrMat)
generic_trimatmul!(C, uploc, isunitc, tfun, A, B)
end

function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUMatrix)
generic_mattrimul!(C, uploc, isunitc, tfun, A, B)
end

end

if VERSION < v"1.10.0-DEV.1365"
# catch other functions that are called by LinearAlgebra's mul!
function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number)
Expand Down
33 changes: 33 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,39 @@
@test istriu(A) == istriu(B)
end
end

if VERSION >= v"1.10-"
@testset "trimatmul" begin
n = 128
b = AT(rand(Float32, n))
B = AT(rand(Float32, n, n))

At = UpperTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B

At = UnitUpperTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B


At = LowerTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B

At = UnitLowerTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B
end
end
end

@testset "diagonal" begin
Expand Down

0 comments on commit 23b3c45

Please sign in to comment.