Skip to content

Commit

Permalink
[NDTensors] Fix scalar indexing issue for Diag broadcast on GPU (#1497)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT authored Jun 21, 2024
1 parent 984d814 commit a1e7ec5
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
module NDTensorsGPUArraysCoreExt
include("contract.jl")
include("blocksparsetensor.jl")
end
26 changes: 26 additions & 0 deletions NDTensors/ext/NDTensorsGPUArraysCoreExt/blocksparsetensor.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using GPUArraysCore: @allowscalar, AbstractGPUArray
using NDTensors: NDTensors, BlockSparseTensor, dense, diag, map_diag!
using NDTensors.DiagonalArrays: diaglength
using NDTensors.Expose: Exposed, unexpose

## TODO to circumvent issues with blocksparse and scalar indexing
## convert blocksparse GPU tensors to dense tensors and call diag
## copying will probably have some impact on timing but this code
## currently isn't used in the main code, just in tests.
function NDTensors.diag(ETensor::Exposed{<:AbstractGPUArray,<:BlockSparseTensor})
return diag(dense(unexpose(ETensor)))
end

## TODO scalar indexing is slow here
function NDTensors.map_diag!(
f::Function,
exposed_t_destination::Exposed{<:AbstractGPUArray,<:BlockSparseTensor},
exposed_t_source::Exposed{<:AbstractGPUArray,<:BlockSparseTensor},
)
t_destination = unexpose(exposed_t_destination)
t_source = unexpose(exposed_t_source)
@allowscalar for i in 1:diaglength(t_destination)
NDTensors.setdiagindex!(t_destination, f(NDTensors.getdiagindex(t_source, i)), i)
end
return t_destination
end
26 changes: 25 additions & 1 deletion NDTensors/src/blocksparse/blocksparsetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ end
# Returns the offset of the new block added.
# XXX rename to insertblock!, no need to return offset
using .TypeParameterAccessors: unwrap_array_type
using .Expose: expose
using .Expose: Exposed, expose, unexpose
function insertblock_offset!(T::BlockSparseTensor{ElT,N}, newblock::Block{N}) where {ElT,N}
newdim = blockdim(T, newblock)
newoffset = nnz(T)
Expand Down Expand Up @@ -356,6 +356,30 @@ function dense(T::TensorT) where {TensorT<:BlockSparseTensor}
return tensor(Dense(r), inds(T))
end

function diag(ETensor::Exposed{<:AbstractArray,<:BlockSparseTensor})
tensor = unexpose(ETensor)
tensordiag = NDTensors.similar(
dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),)
)
for j in 1:diaglength(tensor)
@inbounds tensordiag[j] = getdiagindex(tensor, j)
end
return tensordiag
end

## TODO currently this fails on GPU with scalar indexing
function map_diag!(
f::Function,
exposed_t_destination::Exposed{<:AbstractArray,<:BlockSparseTensor},
exposed_t_source::Exposed{<:AbstractArray,<:BlockSparseTensor},
)
t_destination = unexpose(exposed_t_destination)
t_source = unexpose(exposed_t_source)
for i in 1:diaglength(t_destination)
NDTensors.setdiagindex!(t_destination, f(NDTensors.getdiagindex(t_source, i)), i)
end
return t_destination
end
#
# Operations
#
Expand Down
6 changes: 6 additions & 0 deletions NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ convert(::Type{Array}, T::DenseTensor) = reshape(data(storage(T)), dims(inds(T))
# Useful for using Base Array functions
array(T::DenseTensor) = convert(Array, T)

using .DiagonalArrays: DiagonalArrays, diagview

function DiagonalArrays.diagview(T::DenseTensor)
return diagview(array(T))
end

function Array{ElT,N}(T::DenseTensor{ElT,N}) where {ElT,N}
return copy(array(T))
end
Expand Down
43 changes: 10 additions & 33 deletions NDTensors/src/diag/diagtensor.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using .DiagonalArrays: diaglength
using .DiagonalArrays: diaglength, diagview

const DiagTensor{ElT,N,StoreT,IndsT} = Tensor{ElT,N,StoreT,IndsT} where {StoreT<:Diag}
const NonuniformDiagTensor{ElT,N,StoreT,IndsT} =
Expand All @@ -9,9 +9,7 @@ const UniformDiagTensor{ElT,N,StoreT,IndsT} =
function diag(tensor::DiagTensor)
tensor_diag = NDTensors.similar(dense(typeof(tensor)), (diaglength(tensor),))
# TODO: Define `eachdiagindex`.
for j in 1:diaglength(tensor)
tensor_diag[j] = getdiagindex(tensor, j)
end
diagview(tensor_diag) .= diagview(tensor)
return tensor_diag
end

Expand All @@ -33,6 +31,10 @@ function Array(T::DiagTensor{ElT,N}) where {ElT,N}
return Array{ElT,N}(T)
end

function DiagonalArrays.diagview(T::NonuniformDiagTensor)
return data(T)
end

function zeros(tensortype::Type{<:DiagTensor}, inds)
return tensor(generic_zeros(storagetype(tensortype), mindim(inds)), inds)
end
Expand Down Expand Up @@ -110,32 +112,11 @@ end
using .TypeParameterAccessors: unwrap_array_type
# convert to Dense
function dense(T::DiagTensor)
return dense(unwrap_array_type(T), T)
end

# CPU version
function dense(::Type{<:Array}, T::DiagTensor)
R = zeros(dense(typeof(T)), inds(T))
for i in 1:diaglength(T)
setdiagindex!(R, getdiagindex(T, i), i)
end
diagview(R) .= diagview(T)
return R
end

# GPU version
function dense(::Type{<:AbstractArray}, T::DiagTensor)
D_cpu = dense(Array, cpu(T))
return adapt(unwrap_array_type(T), D_cpu)
end

# UniformDiag version
# TODO: Delete once new DiagonalArray is designed.
# TODO: This creates a tensor on CPU by default so may cause
# problems for GPU.
function dense(::Type{<:Number}, T::DiagTensor)
return dense(Tensor(Diag(fill(getdiagindex(T, 1), diaglength(T))), inds(T)))
end

denseblocks(T::DiagTensor) = dense(T)

function permutedims!(
Expand All @@ -145,16 +126,14 @@ function permutedims!(
f::Function=(r, t) -> t,
) where {N}
# TODO: check that inds(R)==permute(inds(T),perm)?
for i in 1:diaglength(R)
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
end
diagview(R) .= f.(diagview(R), diagview(T))
return R
end

function permutedims(
T::DiagTensor{<:Number,N}, perm::NTuple{N,Int}, f::Function=identity
) where {N}
R = NDTensors.similar(T, permute(inds(T), perm))
R = NDTensors.similar(T)
g(r, t) = f(t)
permutedims!(R, T, perm, g)
return R
Expand Down Expand Up @@ -193,9 +172,7 @@ end
function permutedims!(
R::DenseTensor{ElR,N}, T::DiagTensor{ElT,N}, perm::NTuple{N,Int}, f::Function=(r, t) -> t
) where {ElR,ElT,N}
for i in 1:diaglength(T)
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
end
diagview(R) .= f.(diagview(R), diagview(T))
return R
end

Expand Down
1 change: 0 additions & 1 deletion NDTensors/src/linearalgebra/linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,6 @@ matrix is unique. Returns a tuple (Q,R).
function qr_positive(M::AbstractMatrix)
sparseQ, R = qr(M)
Q = convert(typeof(R), sparseQ)
nc = size(Q, 2)
signs = nonzero_sign.(diag(R))
Q = Q * Diagonal(signs)
R = Diagonal(conj.(signs)) * R
Expand Down
16 changes: 12 additions & 4 deletions NDTensors/src/tensor/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,18 @@ function getdiagindex(T::Tensor{<:Number,N}, ind::Int) where {N}
return getindex(T, CartesianIndex(ntuple(_ -> ind, Val(N))))
end

using .Expose: Exposed, expose, unexpose
# TODO: add support for off-diagonals, return
# block sparse vector instead of dense.
function diag(tensor::Tensor)
diag(tensor::Tensor) = diag(expose(tensor))

function diag(ETensor::Exposed)
tensor = unexpose(ETensor)
## d = NDTensors.similar(T, ElT, (diaglength(T),))
tensordiag = NDTensors.similar(
dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),)
)
for n in 1:diaglength(tensor)
tensordiag[n] = tensor[n, n]
end
array(tensordiag) .= diagview(tensor)
return tensordiag
end

Expand All @@ -384,6 +386,12 @@ function setdiagindex!(T::Tensor{<:Number,N}, val, ind::Int) where {N}
return T
end

function map_diag!(f::Function, exposed_t_destination::Exposed, exposed_t_source::Exposed)
diagview(unexpose(exposed_t_destination)) .= f.(diagview(unexpose(exposed_t_source)))
return unexpose(exposed_t_destination)
end
map_diag(f::Function, t::Tensor) = map_diag!(f, expose(copy(t)), expose(t))

#
# Some generic contraction functionality
#
Expand Down
8 changes: 8 additions & 0 deletions NDTensors/test/test_blocksparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using NDTensors:
blockview,
data,
dense,
diag,
diaglength,
dims,
eachnzblock,
inds,
Expand Down Expand Up @@ -52,6 +54,8 @@ using Test: @test, @test_throws, @testset
@test isblocknz(A, (1, 2))
@test !isblocknz(A, (1, 1))
@test !isblocknz(A, (2, 2))
dA = diag(A)
@test @allowscalar dA diag(dense(A))

# Test different ways of getting nnz
@test nnz(blockoffsets(A), inds(A)) == nnz(A)
Expand Down Expand Up @@ -104,6 +108,10 @@ using Test: @test, @test_throws, @testset
@allowscalar for I in eachindex(C)
@test C[I] == A[I] + B[I]
end
Cp = NDTensors.map_diag(i -> 2 * i, C)
@allowscalar for i in 1:diaglength(Cp)
@test Cp[i, i] == 2 * C[i, i]
end

Ap = permutedims(A, (2, 1))

Expand Down
4 changes: 4 additions & 0 deletions NDTensors/test/test_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ NDTensors.dim(i::MyInd) = i.dim
randn!(B)
C = copy(A)
C = permutedims!!(C, B, (1, 2), +)
Cp = NDTensors.map_diag(i -> 2 * i, C)
@allowscalar for i in 1:diaglength(Cp)
@test Cp[i, i] == 2 * C[i, i]
end

Ap = permutedims(A, (2, 1))
@allowscalar begin
Expand Down
24 changes: 22 additions & 2 deletions NDTensors/test/test_diag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,39 @@ using LinearAlgebra: dot
D = Tensor(Diag(1), (2, 2))
@test norm(D) == 2
d = 3
## TODO this fails because uniform diag tensors are immutable
#S = NDTensors.map_diag((i->i * 2), dev(D))
# @allowscalar for i in 1:diaglength(S)
# @test S[i,i] == 2.0 * D[i,i]
# end

vr = rand(elt, d)
D = dev(tensor(Diag(vr), (d, d)))
Da = Array(D)
Dm = Matrix(D)
Da = permutedims(D, (2, 1))
@allowscalar begin
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)

## TODO Currently this permutedims requires scalar indexing on GPU.
Da = permutedims(D, (2, 1))
@test Da == D
end

# This if statement corresponds to the reported bug:
# https://github.com/JuliaGPU/Metal.jl/issues/364
if !(dev == NDTensors.mtl && elt === ComplexF32)
S = permutedims(dev(D), (1, 2), sqrt)
@allowscalar begin
for i in 1:diaglength(S)
@test S[i, i] sqrt(D[i, i])
end
end
end
S = NDTensors.map_diag(i -> 2 * i, dev(D))
@allowscalar for i in 1:diaglength(S)
@test S[i, i] == 2 * D[i, i]
end

# Regression test for https://github.com/ITensor/ITensors.jl/issues/1199
S = dev(tensor(Diag(randn(elt, 2)), (2, 2)))
## This was creating a `Dense{ReshapedArray{Adjoint{Matrix}}}` which, in mul!, was
Expand Down
6 changes: 3 additions & 3 deletions jenkins/Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pipeline {
}
steps {
sh '''
julia -e 'using Pkg; Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
julia -e 'using Pkg; Pkg.Registry.update(); Pkg.update(); Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
'''
}
}
Expand All @@ -51,7 +51,7 @@ pipeline {
}
steps {
sh '''
julia -e 'using Pkg; Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
julia -e 'using Pkg; Pkg.Registry.update(); Pkg.update(); Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cuda"])'
'''
}
}
Expand All @@ -75,7 +75,7 @@ pipeline {
}
steps {
sh '''
julia -e 'using Pkg; Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cutensor"])'
julia -e 'using Pkg; Pkg.Registry.update(); Pkg.update(); Pkg.activate(temp=true); Pkg.develop(path="./NDTensors"); Pkg.develop(path="."); Pkg.test("NDTensors"; test_args=["cutensor"])'
'''
}
}
Expand Down

0 comments on commit a1e7ec5

Please sign in to comment.