From cea22f97269d5bf9bac4160ace7a143513b6fafd Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 25 Oct 2023 10:55:03 -0400 Subject: [PATCH] [NDTensors] Avoid more scalar indexing operations in block sparse GPU code (#1217) --- .../NDTensorsMetalExt/NDTensorsMetalExt.jl | 4 +++- NDTensors/ext/NDTensorsMetalExt/append.jl | 5 ++++ NDTensors/ext/NDTensorsMetalExt/mul.jl | 15 ++++++++++++ NDTensors/src/NDTensors.jl | 1 + NDTensors/src/abstractarray/append.jl | 10 ++++++++ .../src/blocksparse/blocksparsetensor.jl | 7 ++++-- NDTensors/src/blocksparse/linearalgebra.jl | 24 ++++++++++--------- NDTensors/src/dense/tensoralgebra/contract.jl | 8 +++---- NDTensors/src/linearalgebra/linearalgebra.jl | 3 ++- NDTensors/src/truncate.jl | 2 +- 10 files changed, 59 insertions(+), 20 deletions(-) create mode 100644 NDTensors/ext/NDTensorsMetalExt/append.jl create mode 100644 NDTensors/ext/NDTensorsMetalExt/mul.jl create mode 100644 NDTensors/src/abstractarray/append.jl diff --git a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl index 64b0e4477e..89ee4d42a3 100644 --- a/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl +++ b/NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl @@ -2,7 +2,7 @@ module NDTensorsMetalExt using Adapt using Functors -using LinearAlgebra: LinearAlgebra +using LinearAlgebra: LinearAlgebra, Transpose, mul! using NDTensors using NDTensors.SetParameters @@ -18,5 +18,7 @@ include("set_types.jl") include("indexing.jl") include("linearalgebra.jl") include("copyto.jl") +include("append.jl") include("permutedims.jl") +include("mul.jl") end diff --git a/NDTensors/ext/NDTensorsMetalExt/append.jl b/NDTensors/ext/NDTensorsMetalExt/append.jl new file mode 100644 index 0000000000..7487004f2b --- /dev/null +++ b/NDTensors/ext/NDTensorsMetalExt/append.jl @@ -0,0 +1,5 @@ +# This circumvents an issues that `MtlArray` can't call `resize!`. +# TODO: Raise an issue with Metal.jl. +function NDTensors.append!!(::Type{<:MtlArray}, collection, collections...) + return vcat(collection, collections...) +end diff --git a/NDTensors/ext/NDTensorsMetalExt/mul.jl b/NDTensors/ext/NDTensorsMetalExt/mul.jl new file mode 100644 index 0000000000..6a366f972a --- /dev/null +++ b/NDTensors/ext/NDTensorsMetalExt/mul.jl @@ -0,0 +1,15 @@ +# This was calling generic matrix multiplication. +# TODO: Raise an issue with `Metal.jl`. +function NDTensors.mul!!( + ::Type{<:MtlArray}, + CM::Transpose, + ::Type{<:MtlArray}, + AM::AbstractMatrix, + ::Type{<:MtlArray}, + BM::AbstractMatrix, + α, + β, +) + mul!(transpose(CM), transpose(BM), transpose(AM), α, β) + return CM +end diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index 29501e2cac..ec53a32515 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -53,6 +53,7 @@ include("abstractarray/iscu.jl") include("abstractarray/similar.jl") include("abstractarray/ndims.jl") include("abstractarray/copyto.jl") +include("abstractarray/append.jl") include("abstractarray/permutedims.jl") include("abstractarray/fill.jl") include("abstractarray/mul.jl") diff --git a/NDTensors/src/abstractarray/append.jl b/NDTensors/src/abstractarray/append.jl new file mode 100644 index 0000000000..c842914051 --- /dev/null +++ b/NDTensors/src/abstractarray/append.jl @@ -0,0 +1,10 @@ +# NDTensors.append! +# Used to circumvent issues with some GPU backends like Metal +# not supporting `resize!`. +function append!!(collection, collections...) + return append!!(leaf_parenttype(collection), collection, collections...) +end + +function append!!(::Type, collection, collections...) + return append!(collection, collections...) +end diff --git a/NDTensors/src/blocksparse/blocksparsetensor.jl b/NDTensors/src/blocksparse/blocksparsetensor.jl index a591038512..14bce7d32d 100644 --- a/NDTensors/src/blocksparse/blocksparsetensor.jl +++ b/NDTensors/src/blocksparse/blocksparsetensor.jl @@ -653,7 +653,10 @@ function uncombine( #copyto!(Rb,Tb) if length(Tb) == 1 - Rb[] = Tb[] + # Call `cpu` to avoid allowscalar error on GPU. + # TODO: Replace with `@allowscalar`, requires adding + # `GPUArraysCore.jl` as a dependency. + Rb[] = cpu(Tb)[] else # XXX: this used to be: # Rbₐᵣ = ReshapedArray(parent(Rbₐ), size(Tb), ()) @@ -712,7 +715,7 @@ function permutedims!!( ## copyto!(data(RR), data(R)) if new_nnz > nnz(RR) - dataRR = append!(data(RR), zeros(new_nnz - nnz(RR))) + dataRR = append!!(data(RR), generic_zeros(leaf_parenttype(R), new_nnz - nnz(RR))) RR = Tensor(BlockSparse(dataRR, bofsRR), inds(RR)) end diff --git a/NDTensors/src/blocksparse/linearalgebra.jl b/NDTensors/src/blocksparse/linearalgebra.jl index d4a64c9085..931f8b6011 100644 --- a/NDTensors/src/blocksparse/linearalgebra.jl +++ b/NDTensors/src/blocksparse/linearalgebra.jl @@ -6,6 +6,8 @@ const DiagMatrix{ElT,StoreT,IndsT} = DiagTensor{ElT,2,StoreT,IndsT} function _truncated_blockdim( S::DiagMatrix, docut::Real; singular_values=false, truncate=true, min_blockdim=0 ) + # TODO: Replace `cpu` with `leaf_parenttype` dispatch. + S = cpu(S) full_dim = diaglength(S) !truncate && return full_dim min_blockdim = min(min_blockdim, full_dim) @@ -84,7 +86,8 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT} if blockdim == 0 push!(dropblocks, n) else - Strunc = tensor(Diag(storage(Ss[n])[1:blockdim]), (blockdim, blockdim)) + # TODO: Replace call to `data` with `diagview`. + Strunc = tensor(Diag(data(Ss[n])[1:blockdim]), (blockdim, blockdim)) Us[n] = Us[n][1:dim(Us[n], 1), 1:blockdim] Ss[n] = Strunc Vs[n] = Vs[n][1:dim(Vs[n], 1), 1:blockdim] @@ -177,9 +180,8 @@ function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT} copyto!(blockview(U, blockU), Ub) blockviewS = blockview(S, blockS) - for i in 1:diaglength(Sb) - setdiagindex!(blockviewS, getdiagindex(Sb, i), i) - end + # TODO: Replace `data` with `diagview`. + copyto!(data(blockviewS), data(Sb)) # sV = left_arrow_sign(vind, blockV[2]) @@ -243,7 +245,8 @@ function eigen( if blockdim == 0 push!(dropblocks, n) else - Dtrunc = tensor(Diag(storage(Ds[n])[1:blockdim]), (blockdim, blockdim)) + # TODO: Replace call to `data` with `diagview`. + Dtrunc = tensor(Diag(data(Ds[n])[1:blockdim]), (blockdim, blockdim)) Ds[n] = Dtrunc new_size = (dim(Vs[n], 1), blockdim) new_data = array(Vs[n])[1:new_size[1], 1:new_size[2]] @@ -311,12 +314,11 @@ function eigen( blockD = nzblocksD[n] blockviewD = blockview(D, blockD) - for i in 1:diaglength(Db) - setdiagindex!(blockviewD, getdiagindex(Db, i), i) - end + # TODO: Replace `data` with `diagview`. + copyto!(data(blockviewD), data(Db)) blockV = nzblocksV[n] - blockview(V, blockV) .= Vb + copyto!(blockview(V, blockV), Vb) end return D, V, Spectrum(d, truncerr) @@ -380,8 +382,8 @@ function qx(qx::Function, T::BlockSparseTensor{<:Any,2}; kwargs...) X = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksX, indsX) for n in 1:nnzblocksT - blockview(Q, nzblocksQ[n]) .= Qs[n] - blockview(X, nzblocksX[n]) .= Xs[n] + copyto!(blockview(Q, nzblocksQ[n]), Qs[n]) + copyto!(blockview(X, nzblocksX[n]), Xs[n]) end Q = adapt(leaf_parenttype(T), Q) diff --git a/NDTensors/src/dense/tensoralgebra/contract.jl b/NDTensors/src/dense/tensoralgebra/contract.jl index e1d99ce2eb..5171507ce6 100644 --- a/NDTensors/src/dense/tensoralgebra/contract.jl +++ b/NDTensors/src/dense/tensoralgebra/contract.jl @@ -15,11 +15,11 @@ function _contract_scalar!( β=zero(ElR), ) where {ElR} if iszero(β) - R[1] = α * T1 * T2 + R[] = α * T1 * T2 elseif iszero(α) - R[1] = β * R[1] + R[] = β * R[] else - R[1] = α * T1 * T2 + β * R[1] + R[] = α * T1 * T2 + β * R[] end return R end @@ -150,7 +150,7 @@ function _contract_scalar!( β=zero(ElR), ) where {ElR} if nnz(T1) == nnz(T2) == 1 - _contract_scalar!(R, labelsR, T1[1], labelsT1, T2[1], labelsT2, α, β) + _contract_scalar!(R, labelsR, T1[], labelsT1, T2[], labelsT2, α, β) else _contract_scalar_maybe_perm!(R, labelsR, T1, labelsT1, T2, labelsT2, α, β) end diff --git a/NDTensors/src/linearalgebra/linearalgebra.jl b/NDTensors/src/linearalgebra/linearalgebra.jl index cab3f7e057..20dec6b3a9 100644 --- a/NDTensors/src/linearalgebra/linearalgebra.jl +++ b/NDTensors/src/linearalgebra/linearalgebra.jl @@ -231,7 +231,8 @@ function eigen( DM, VM = eigen(matrixT) # Sort by largest to smallest eigenvalues - p = sortperm(DM; rev=true, by=abs) + # TODO: Replace `cpu` with `leaf_parenttype` dispatch. + p = sortperm(cpu(DM); rev=true, by=abs) DM = DM[p] VM = VM[:, p] diff --git a/NDTensors/src/truncate.jl b/NDTensors/src/truncate.jl index b92fa8a1f1..829c351127 100644 --- a/NDTensors/src/truncate.jl +++ b/NDTensors/src/truncate.jl @@ -11,7 +11,7 @@ end # GPU fallback version, convert to CPU. function truncate!!(::Type{<:AbstractArray}, P::AbstractArray; kwargs...) P_cpu = cpu(P) - P_cpu, truncerr, docut = truncate!(P_cpu; kwargs...) + truncerr, docut = truncate!(P_cpu; kwargs...) P = adapt(leaf_parenttype(P), P_cpu) return P, truncerr, docut end