Skip to content

Commit

Permalink
Adapt to new LinearAlgebra.generic_*mul! interface
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Feb 28, 2024
1 parent f3181d4 commit 1c76587
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,10 @@ end


## matrix multiplication

function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, a::Number, b::Number) where {T,S,R}
# legacy method
generic_matmatmul!(C::AbstractArray, A::AbstractArray, B::AbstractArray, a::Number, b::Number) where {T,S,R} =
generic_matmatmul!(C, A, B, MulAddMul(a, b))
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, add::MulAddMul) where {T,S,R}
if size(A,2) != size(B,1)
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
end
Expand All @@ -350,8 +352,6 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
return fill!(C, zero(R))
end

add = MulAddMul(a, b)

gpu_call(C, A, B; name="matmatmul!") do ctx, C, A, B
idx = @linearidx C
assume.(size(C) .> 0)
Expand All @@ -372,42 +372,52 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
C
end

if VERSION < v"1.12.0-"
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul())
generic_matmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta)
generic_matmatmul!(C, wrap(A, tA), B, _add)
end

function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul=MulAddMul())
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
else
function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number)
LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
end

function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number)
LinearAlgebra.@stable_muladdmul generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(a, b))
end
end

if VERSION < v"1.10.0-DEV.1365"
# catch other functions that are called by LinearAlgebra's mul!
function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number)
generic_matmatmul!(C, wrap(A, tA), B, a, b)
generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
end
# disambiguation
function LinearAlgebra.gemv!(C::AbstractGPUVector{T}, tA::AbstractChar, A::AbstractGPUMatrix{T}, B::AbstractGPUVector{T}, a::Number, b::Number) where {T<:LinearAlgebra.BlasFloat}
generic_matmatmul!(C, wrap(A, tA), B, a, b)
generic_matmatmul!(C, wrap(A, tA), B, MulAddMul(a, b))
end

LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
# disambiguation
LinearAlgebra.gemm_wrapper!(C::AbstractGPUVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, _add::MulAddMul) where {T<:LinearAlgebra.BlasFloat} =
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add)
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)

function LinearAlgebra.syrk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul())
if tA == 'T'
LinearAlgebra.generic_matmatmul!(C, 'T', 'N', A, A, _add)
generic_matmatmul!(C, wrap(A, 'T'), A, _add)
else # tA == 'N'
LinearAlgebra.generic_matmatmul!(C, 'N', 'T', A, A, _add)
generic_matmatmul!(C, A, wrap(A, 'T'), _add)
end
end
function LinearAlgebra.herk_wrapper!(C::AbstractGPUMatrix, tA::AbstractChar, A::AbstractGPUVecOrMat, _add::MulAddMul = MulAddMul())
if tA == 'C'
LinearAlgebra.generic_matmatmul!(C, 'C', 'N', A, A, _add)
generic_matmatmul!(C, wrap(A, 'C'), A, _add)
else # tA == 'N'
LinearAlgebra.generic_matmatmul!(C, 'N', 'C', A, A, _add)
generic_matmatmul!(C, A, wrap(A, 'C'), _add)
end
end
end # VERSION
Expand Down

0 comments on commit 1c76587

Please sign in to comment.