[NDTensors] Get more block sparse operations working on GPU #1215

Merged: 12 commits, Oct 24, 2023
1 change: 1 addition & 0 deletions NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl
@@ -20,5 +20,6 @@ include("imports.jl")
include("set_types.jl")
include("iscu.jl")
include("adapt.jl")
include("indexing.jl")
include("linearalgebra.jl")
end
8 changes: 8 additions & 0 deletions NDTensors/ext/NDTensorsCUDAExt/indexing.jl
@@ -0,0 +1,8 @@
function Base.getindex(::Type{<:CuArray}, T::DenseTensor{<:Number})
return CUDA.@allowscalar data(T)[]
end

function Base.setindex!(::Type{<:CuArray}, T::DenseTensor{<:Number}, x::Number)
CUDA.@allowscalar data(T)[] = x
return T
end
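
A usage sketch of the scalar path this enables (hypothetical session; assumes a working CUDA device). The array type is a pure dispatch argument: the generic DenseTensor methods call these as getindex(leaf_parenttype(T), T), and the CuArray specialization wraps the single-element access in CUDA.@allowscalar so it does not trip CUDA.jl's scalar-indexing error:

using CUDA, NDTensors
t = NDTensors.tensor(NDTensors.Dense(CUDA.zeros(Float64, 1)), ())  # zero-dim GPU tensor (constructor usage is a sketch)
t[] = 1.0           # hits the setindex! method above
@assert t[] == 1.0  # hits the getindex method above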
8 changes: 6 additions & 2 deletions NDTensors/ext/NDTensorsMetalExt/NDTensorsMetalExt.jl
@@ -1,9 +1,10 @@
module NDTensorsMetalExt

using Adapt
using Functors
using LinearAlgebra: LinearAlgebra
using NDTensors
using NDTensors.SetParameters
using Functors
using Adapt

if isdefined(Base, :get_extension)
using Metal
@@ -14,4 +15,7 @@ end
include("imports.jl")
include("adapt.jl")
include("set_types.jl")
include("indexing.jl")
include("linearalgebra.jl")
include("permutedims.jl")
end
8 changes: 8 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/indexing.jl
@@ -0,0 +1,8 @@
function Base.getindex(::Type{<:MtlArray}, T::DenseTensor{<:Number})
return Metal.@allowscalar data(T)[]
end

function Base.setindex!(::Type{<:MtlArray}, T::DenseTensor{<:Number}, x::Number)
Metal.@allowscalar data(T)[] = x
return T
end
16 changes: 16 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/linearalgebra.jl
@@ -0,0 +1,16 @@
function NDTensors.qr(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
Q, R = NDTensors.qr(NDTensors.cpu(A))
return adapt(leaf_parenttype, Matrix(Q)), adapt(leaf_parenttype, R)
end

function NDTensors.eigen(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
D, U = NDTensors.eigen(NDTensors.cpu(A))
return adapt(set_ndims(leaf_parenttype, ndims(D)), D), adapt(leaf_parenttype, U)
end

function NDTensors.svd(leaf_parenttype::Type{<:MtlArray}, A::AbstractMatrix)
U, S, V = NDTensors.svd(NDTensors.cpu(A))
return adapt(leaf_parenttype, U),
adapt(set_ndims(leaf_parenttype, ndims(S)), S),
adapt(leaf_parenttype, V)
end
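
Metal.jl (as of this PR) has no native QR/eigen/SVD for MtlArray, so these methods round-trip through the CPU and adapt the factors back to the device; set_ndims adjusts the wrapper type for the eigenvalue/singular-value container, whose dimensionality differs from the input matrix. A minimal sketch, assuming an Apple-silicon GPU and Float32 data (Metal does not support Float64):

using Metal, Adapt, NDTensors
A = mtl(rand(Float32, 4, 4))  # MtlMatrix{Float32}
Q, R = NDTensors.qr(A)        # factors computed on the CPU, adapted back to MtlArray
@assert Q isa MtlArray && R isa MtlArray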
12 changes: 12 additions & 0 deletions NDTensors/ext/NDTensorsMetalExt/permutedims.jl
@@ -0,0 +1,12 @@
function NDTensors.permutedims!(
::Type{<:MtlArray},
Adest::Base.ReshapedArray{<:Any,<:Any,<:SubArray},
::Type{<:MtlArray},
A,
perm,
)
Aperm = permutedims(A, perm)
Adest_parent = parent(Adest)
copyto!(Adest_parent, Aperm)
return Adest
end
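
This method exists because the destination built inside the block-sparse permutation routines is a reshape of a SubArray view into the storage, a wrapper that Metal kernels could not write through directly; the permutation is materialized into a fresh device array and copied into the view's parent. Roughly (names and shapes illustrative):

data = Metal.zeros(Float32, 32)            # block-sparse storage (sketch)
v = view(data, 17:32)                      # one block's slice
Adest = Base.ReshapedArray(v, (4, 4), ())  # the wrapper this method handles
A = mtl(rand(Float32, 4, 4))
NDTensors.permutedims!(MtlArray, Adest, MtlArray, A, (2, 1))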
1 change: 1 addition & 0 deletions NDTensors/src/NDTensors.jl
@@ -53,6 +53,7 @@ include("abstractarray/ndims.jl")
include("abstractarray/permutedims.jl")
include("abstractarray/fill.jl")
include("abstractarray/mul.jl")
include("abstractarray/linearalgebra.jl")
include("array/set_types.jl")
include("array/permutedims.jl")
include("array/mul.jl")
11 changes: 11 additions & 0 deletions NDTensors/src/abstractarray/linearalgebra.jl
@@ -0,0 +1,11 @@
# NDTensors.qr
qr(A::AbstractMatrix) = qr(leaf_parenttype(A), A)
qr(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.qr(A)

# NDTensors.eigen
eigen(A::AbstractMatrix) = eigen(leaf_parenttype(A), A)
eigen(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.eigen(A)

# NDTensors.svd
svd(A::AbstractMatrix) = svd(leaf_parenttype(A), A)
svd(::Type{<:AbstractArray}, A::AbstractMatrix) = LinearAlgebra.svd(A)
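
These NDTensors-owned wrappers dispatch on leaf_parenttype (the innermost parent array type), with the generic AbstractArray method falling back to LinearAlgebra. A GPU backend then only needs to overload the two-argument form; a sketch for a hypothetical array type MyGPUArray, mirroring the Metal extension above:

# MyGPUArray is illustrative, not a real type.
function NDTensors.qr(leaf_parenttype::Type{<:MyGPUArray}, A::AbstractMatrix)
  Q, R = NDTensors.qr(NDTensors.cpu(A))  # e.g. fall back to the CPU
  return adapt(leaf_parenttype, Matrix(Q)), adapt(leaf_parenttype, R)
end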
17 changes: 15 additions & 2 deletions NDTensors/src/abstractarray/permutedims.jl
@@ -12,12 +12,25 @@ end

# NDTensors.permutedims!
function permutedims!(Mdest::AbstractArray, M::AbstractArray, perm)
return permutedims!(leaf_parenttype(Mdest), Mdest, leaf_parenttype(M), M, perm)
permutedims!(leaf_parenttype(Mdest), Mdest, leaf_parenttype(M), M, perm)
return Mdest
end

# NDTensors.permutedims!
function permutedims!(::Type{<:AbstractArray}, Mdest, ::Type{<:AbstractArray}, M, perm)
return Base.permutedims!(Mdest, M, perm)
Base.permutedims!(Mdest, M, perm)
return Mdest
end

function permutedims!!(B::AbstractArray, A::AbstractArray, perm)
return permutedims!!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm)
end

function permutedims!!(
Bleaftype::Type{<:AbstractArray}, B, Aleaftype::Type{<:AbstractArray}, A, perm
)
permutedims!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm)
return B
end

function permutedims!!(B::AbstractArray, A::AbstractArray, perm, f)
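
The double-bang convention on permutedims!! means "may mutate the destination, but callers must use the return value"; the generic path above simply forwards to the in-place permutedims! and returns B, while wrapped or immutable storage types can overload it to return a fresh array. A usage sketch:

using NDTensors
A = reshape(collect(1.0:6.0), 2, 3)
B = zeros(3, 2)
B = NDTensors.permutedims!!(B, A, (2, 1))  # rebind: the result need not alias B
@assert B == permutedims(A, (2, 1))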
10 changes: 10 additions & 0 deletions NDTensors/src/blocksparse/blocksparse.jl
@@ -54,6 +54,16 @@ function BlockSparse(
return BlockSparse(Vector{ElT}(undef, dim), blockoffsets; vargs...)
end

function BlockSparse(
datatype::Type{<:AbstractArray},
::UndefInitializer,
blockoffsets::BlockOffsets,
dim::Integer;
vargs...,
)
return BlockSparse(datatype(undef, dim), blockoffsets; vargs...)
end

function BlockSparse(blockoffsets::BlockOffsets, dim::Integer; vargs...)
return BlockSparse(Float64, blockoffsets, dim; vargs...)
end
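
With the storage constructor above accepting an explicit datatype, block-sparse storage can be allocated directly in device memory. A sketch using the matching BlockSparseTensor constructor added below (assumes CUDA.jl):

using CUDA, NDTensors
# Two nonzero blocks whose data lives in a CuVector from the start:
T = BlockSparseTensor(CuVector{Float64}, undef, [(1, 1), (2, 2)], ([2, 3], [2, 3]))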
44 changes: 38 additions & 6 deletions NDTensors/src/blocksparse/blocksparsetensor.jl
@@ -15,6 +15,14 @@ function BlockSparseTensor(
return tensor(storage, inds)
end

function BlockSparseTensor(
datatype::Type{<:AbstractArray}, ::UndefInitializer, boffs::BlockOffsets, inds
)
nnz_tot = nnz(boffs, inds)
storage = BlockSparse(datatype, undef, boffs, nnz_tot)
return tensor(storage, inds)
end

function BlockSparseTensor(
::Type{ElT}, ::UndefInitializer, blocks::Vector{BlockT}, inds
) where {ElT<:Number,BlockT<:Union{Block,NTuple}}
@@ -23,6 +31,17 @@ function BlockSparseTensor(
return tensor(storage, inds)
end

function BlockSparseTensor(
datatype::Type{<:AbstractArray},
::UndefInitializer,
blocks::Vector{<:Union{Block,NTuple}},
inds,
)
boffs, nnz = blockoffsets(blocks, inds)
storage = BlockSparse(datatype, undef, boffs, nnz)
return tensor(storage, inds)
end

"""
BlockSparseTensor(::UndefInitializer, blocks, inds)

@@ -91,6 +110,14 @@ function BlockSparseTensor(
return tensor(storage, inds)
end

function BlockSparseTensor(
datatype::Type{<:AbstractArray}, blocks::Vector{<:Union{Block,NTuple}}, inds
)
boffs, nnz = blockoffsets(blocks, inds)
storage = BlockSparse(datatype, boffs, nnz)
return tensor(storage, inds)
end

function BlockSparseTensor(
x::Number, blocks::Vector{BlockT}, inds
) where {BlockT<:Union{Block,NTuple}}
@@ -426,7 +453,7 @@ function permutedims_combine_output(
# Combine the blocks (within the newly combined and permuted dimension)
blocks_perm_comb = combine_blocks(blocks_perm_comb, comb_ind_loc, blockcomb)

return BlockSparseTensor(ElT, blocks_perm_comb, is)
return BlockSparseTensor(leaf_parenttype(T), blocks_perm_comb, is)
end

function permutedims_combine(
@@ -487,8 +514,11 @@ function permutedims_combine(

# XXX Not sure what this was for
Rb = reshape(Rb, permute(dims(Tb), perm))
# TODO: Make this `convert` call more general
# for GPUs.
Tbₐ = convert(Array, Tb)
@strided Rb .= permutedims(Tbₐ, perm)
## @strided Rb .= permutedims(Tbₐ, perm)
permutedims!(Rb, Tbₐ, perm)
end

return R
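
The @strided broadcast (from Strided.jl) assumes CPU strided memory and would fall back to scalar indexing on device arrays, so it is replaced with NDTensors.permutedims!, which backends can specialize (see the Metal method above). The dispatch chain, roughly:

# permutedims!(Rb, Tbₐ, perm)
#   -> permutedims!(leaf_parenttype(Rb), Rb, leaf_parenttype(Tbₐ), Tbₐ, perm)
#   -> Base.permutedims!(...) generically, or a backend-specific method on GPU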
@@ -564,7 +594,7 @@ function uncombine_output(
blocks_uncomb_perm = perm_blocks(blocks_uncomb, combdim, invperm(blockperm))
boffs_uncomb_perm, nnz_uncomb_perm = blockoffsets(blocks_uncomb_perm, inds_uncomb_perm)
T_uncomb_perm = tensor(
BlockSparse(ElT, boffs_uncomb_perm, nnz_uncomb_perm), inds_uncomb_perm
BlockSparse(leaf_parenttype(T), boffs_uncomb_perm, nnz_uncomb_perm), inds_uncomb_perm
)
R = reshape(T_uncomb_perm, is)
return R
@@ -623,14 +653,16 @@ function uncombine(
#copyto!(Rb,Tb)

if length(Tb) == 1
Rb[1] = Tb[1]
Rb[] = Tb[]
else
# XXX: this used to be:
# Rbₐᵣ = ReshapedArray(parent(Rbₐ), size(Tb), ())
# however that doesn't work with subarrays
Rbₐ = convert(Array, Rb)
Rbₐᵣ = ReshapedArray(Rbₐ, size(Tb), ())
@strided Rbₐᵣ .= Tb
## Rbₐᵣ = ReshapedArray(Rbₐ, size(Tb), ())
Rbₐᵣ = reshape(Rbₐ, size(Tb))
## @strided Rbₐᵣ .= Tb
copyto!(Rbₐᵣ, Tb)
end
end
end
9 changes: 9 additions & 0 deletions NDTensors/src/blocksparse/diagblocksparse.jl
@@ -62,6 +62,15 @@ function DiagBlockSparse(
return DiagBlockSparse(Vector{ElT}(undef, diaglength), boffs)
end

function DiagBlockSparse(
datatype::Type{<:AbstractArray},
::UndefInitializer,
boffs::BlockOffsets,
diaglength::Integer,
)
return DiagBlockSparse(datatype(undef, diaglength), boffs)
end

function DiagBlockSparse(::UndefInitializer, boffs::BlockOffsets, diaglength::Integer)
return DiagBlockSparse(Float64, undef, boffs, diaglength)
end
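
The same datatype-threading applies to diagonal block-sparse storage; the block-sparse svd composes it with set_eltype so singular values land in a real-eltype vector of the same device family as the input. A sketch (assumes CUDA.jl; boffs and diaglength stand in for actual block offsets and a diagonal length):

# set_eltype usage mirrors the svd code in this PR.
datatype = set_eltype(CuVector{ComplexF32}, Float32)  # -> CuVector{Float32}
S = DiagBlockSparse(datatype, undef, boffs, diaglength)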
29 changes: 19 additions & 10 deletions NDTensors/src/blocksparse/linearalgebra.jl
@@ -32,12 +32,13 @@ per row/column, otherwise it fails.
This assumption makes it so the result can be
computed from the dense SVDs of separate blocks.
"""
function LinearAlgebra.svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
function svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
alg::String = get(kwargs, :alg, "divide_and_conquer")
min_blockdim::Int = get(kwargs, :min_blockdim, 0)
truncate = haskey(kwargs, :maxdim) || haskey(kwargs, :cutoff)

#@timeit_debug timer "block sparse svd" begin

Us = Vector{DenseTensor{ElT,2}}(undef, nnzblocks(T))
Ss = Vector{DiagTensor{real(ElT),2}}(undef, nnzblocks(T))
Vs = Vector{DenseTensor{ElT,2}}(undef, nnzblocks(T))
@@ -146,9 +147,11 @@ function LinearAlgebra.svd(T::BlockSparseMatrix{ElT}; kwargs...) where {ElT}
indsS = setindex(inds(T), dag(uind), 1)
indsS = setindex(indsS, dag(vind), 2)

U = BlockSparseTensor(ElT, undef, nzblocksU, indsU)
S = DiagBlockSparseTensor(real(ElT), undef, nzblocksS, indsS)
V = BlockSparseTensor(ElT, undef, nzblocksV, indsV)
U = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksU, indsU)
S = DiagBlockSparseTensor(
set_eltype(leaf_parenttype(T), real(ElT)), undef, nzblocksS, indsS
)
V = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksV, indsV)

for n in 1:nnzblocksT
Ub, Sb, Vb = Us[n], Ss[n], Vs[n]
@@ -204,7 +207,7 @@ _eigen_eltypes(T::Hermitian{ElT,<:BlockSparseMatrix{ElT}}) where {ElT} = real(El

_eigen_eltypes(T::BlockSparseMatrix{ElT}) where {ElT} = complex(ElT), complex(ElT)

function LinearAlgebra.eigen(
function eigen(
T::Union{Hermitian{ElT,<:BlockSparseMatrix{ElT}},BlockSparseMatrix{ElT}}; kwargs...
) where {ElT<:Union{Real,Complex}}
truncate = haskey(kwargs, :maxdim) || haskey(kwargs, :cutoff)
@@ -245,7 +248,9 @@ function LinearAlgebra.eigen(
else
Dtrunc = tensor(Diag(storage(Ds[n])[1:blockdim]), (blockdim, blockdim))
Ds[n] = Dtrunc
Vs[n] = copy(Vs[n][1:dim(Vs[n], 1), 1:blockdim])
new_size = (dim(Vs[n], 1), blockdim)
new_data = array(Vs[n])[1:new_size[1], 1:new_size[2]]
Vs[n] = tensor(Dense(new_data), new_size)
end
end
deleteat!(Ds, dropblocks)
@@ -299,8 +304,10 @@ function LinearAlgebra.eigen(
nzblocksV[n] = blockV
end

D = DiagBlockSparseTensor(ElD, undef, nzblocksD, indsD)
V = BlockSparseTensor(ElV, undef, nzblocksV, indsV)
D = DiagBlockSparseTensor(
set_ndims(set_eltype(leaf_parenttype(T), ElD), 1), undef, nzblocksD, indsD
)
V = BlockSparseTensor(set_eltype(leaf_parenttype(T), ElV), undef, nzblocksV, indsV)

for n in 1:nnzblocksT
Db, Vb = Ds[n], Vs[n]
@@ -372,14 +379,16 @@ function qx(qx::Function, T::BlockSparseTensor{<:Any,2}; kwargs...)
nzblocksX[n] = (UInt(n), blockT[2])
end

Q = BlockSparseTensor(ElT, undef, nzblocksQ, indsQ)
X = BlockSparseTensor(ElT, undef, nzblocksX, indsX)
Q = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksQ, indsQ)
X = BlockSparseTensor(leaf_parenttype(T), undef, nzblocksX, indsX)

for n in 1:nnzblocksT
blockview(Q, nzblocksQ[n]) .= Qs[n]
blockview(X, nzblocksX[n]) .= Xs[n]
end

Q = adapt(leaf_parenttype(T), Q)
X = adapt(leaf_parenttype(T), X)
return Q, X
end

20 changes: 18 additions & 2 deletions NDTensors/src/dense/densetensor.jl
@@ -81,7 +81,11 @@ end
#

@propagate_inbounds function getindex(T::DenseTensor{<:Number})
return (iscu(T) ? NDTensors.cpu(data(T))[] : data(T)[])
return getindex(leaf_parenttype(T), T)
end

@propagate_inbounds function getindex(::Type{<:AbstractArray}, T::DenseTensor{<:Number})
return data(T)[]
end

@propagate_inbounds function getindex(T::DenseTensor{<:Number}, I::Integer...)
@@ -110,6 +114,18 @@ end
return T
end

@propagate_inbounds function setindex!(T::DenseTensor{<:Number}, x::Number)
setindex!(leaf_parenttype(T), T, x)
return T
end

@propagate_inbounds function setindex!(
::Type{<:AbstractArray}, T::DenseTensor{<:Number}, x::Number
)
data(T)[] = x
return T
end
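
Together with the getindex counterpart above, this replaces the old iscu branch (which copied storage to the CPU just to read a single scalar) with leaf_parenttype dispatch: the generic AbstractArray methods touch data(T) directly, and the CUDA/Metal extensions wrap the access in @allowscalar. The chain for a zero-dimensional tensor t, roughly:

# t[] = x
#   -> setindex!(leaf_parenttype(t), t, x)
#   -> data(t)[] = x                        (generic AbstractArray method)
#   or CUDA.@allowscalar data(t)[] = x      (NDTensorsCUDAExt method)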

#
# Linear indexing
#
@@ -242,7 +258,7 @@ function permutedims!(
R::DenseTensor{<:Number,N}, T::DenseTensor{<:Number,N}, perm, f::Function
) where {N}
if nnz(R) == 1 && nnz(T) == 1
R[1] = f(R[1], T[1])
R[] = f(R[], T[])
return R
end
RA = array(R)
2 changes: 1 addition & 1 deletion NDTensors/src/dense/linearalgebra/decompositions.jl
@@ -17,7 +17,7 @@ end

# svd of an order-n tensor according to positions Lpos
# and Rpos
function LinearAlgebra.svd(
function svd(
T::DenseTensor{<:Number,N,IndsT}, Lpos::NTuple{NL,Int}, Rpos::NTuple{NR,Int}; kwargs...
) where {N,IndsT,NL,NR}
M = permute_reshape(T, Lpos, Rpos)
2 changes: 1 addition & 1 deletion NDTensors/src/imports.jl
@@ -53,6 +53,6 @@ import Base.Broadcast: Broadcasted, BroadcastStyle

import Adapt: adapt_structure, adapt_storage

import LinearAlgebra: diag, exp, norm, qr
import LinearAlgebra: diag, exp, norm

import TupleTools: isperm