Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ITensors][NDTensors] An attempt to address contract compile time #1125

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions NDTensors/src/dense/tensoralgebra/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,69 +328,61 @@ function _contract!(
β::Number=zero(El),
) where {El,NC,NA,NB}
# TODO: directly use Tensor instead of Array
C = ReshapedArray(data(storage(CT)), dims(inds(CT)), ())
A = ReshapedArray(data(storage(AT)), dims(inds(AT)), ())
B = ReshapedArray(data(storage(BT)), dims(inds(BT)), ())

tA = 'N'
if props.permuteA
pA = NTuple{NA,Int}(props.PA)
#@timeit_debug timer "_contract!: permutedims A" begin
@strided Ap = permutedims(A, pA)
@strided Ap = permutedims(AT, pA)
#end # @timeit
AM = ReshapedArray(Ap, (props.dmid, props.dleft), ())
tA = 'T'
AM = transpose(ReshapedArray(Ap, (props.dmid, props.dleft), ()))
else
#A doesn't have to be permuted
if Atrans(props)
AM = ReshapedArray(A.parent, (props.dmid, props.dleft), ())
tA = 'T'
AM = transpose(ReshapedArray(AT, (props.dmid, props.dleft), ()))
else
AM = ReshapedArray(A.parent, (props.dleft, props.dmid), ())
AM = ReshapedArray(AT, (props.dleft, props.dmid), ())
end
end

tB = 'N'
if props.permuteB
pB = NTuple{NB,Int}(props.PB)
#@timeit_debug timer "_contract!: permutedims B" begin
@strided Bp = permutedims(B, pB)
@strided Bp = permutedims(BT, pB)
#end # @timeit
BM = ReshapedArray(Bp, (props.dmid, props.dright), ())
else
if Btrans(props)
BM = ReshapedArray(B.parent, (props.dright, props.dmid), ())
tB = 'T'
BM = transpose(ReshapedArray(BT, (props.dright, props.dmid), ()))
else
BM = ReshapedArray(B.parent, (props.dmid, props.dright), ())
BM = ReshapedArray(BT, (props.dmid, props.dright), ())
end
end

# TODO: this logic may be wrong
if props.permuteC
# Need to copy here since we will be permuting
# into C later
CM = ReshapedArray(copy(C), (props.dleft, props.dright), ())
CM = ReshapedArray(copy(CT), (props.dleft, props.dright), ())
else
if Ctrans(props)
CM = ReshapedArray(C.parent, (props.dright, props.dleft), ())
(AM, BM) = (BM, AM)
if tA == tB
tA = tB = (tA == 'T' ? 'N' : 'T')
end
CM = transpose(ReshapedArray(CT, (props.dright, props.dleft), ()))
else
CM = ReshapedArray(C.parent, (props.dleft, props.dright), ())
CM = ReshapedArray(CT, (props.dleft, props.dright), ())
end
end

_gemm!(tA, tB, El(α), AM, BM, El(β), CM)
#tC = similar(CM)
#_gemm!(tA, tB, El(α), AM, BM, El(β), CM)
mul!(CM, AM, BM, El(α), El(β))

if props.permuteC
pC = NTuple{NC,Int}(props.PC)
Cr = ReshapedArray(CM.parent, props.newCrange, ())
# TODO: use invperm(pC) here?
#@timeit_debug timer "_contract!: permutedims C" begin
@strided C .= permutedims(Cr, pC)
@strided CT .= permutedims(Cr, pC)
#end # @timeit
end

Expand Down
1 change: 1 addition & 0 deletions src/ITensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ include("indexset.jl")
#
include("itensor.jl")
include("tensor_operations/tensor_algebra.jl")
include("tensor_operations/matrix_contract.jl")
include("tensor_operations/matrix_algebra.jl")
include("tensor_operations/permutations.jl")
include("broadcast.jl")
Expand Down
60 changes: 60 additions & 0 deletions src/tensor_operations/matrix_contract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
function matrix_contract(A::ITensor, B::ITensor)
A_inds = inds(A)
B_inds = inds(B)
NA = ndims(A)
NB = ndims(B)

labelsA, labelsB = compute_contraction_labels(inds(A), inds(B))
labelsoutput_tensor = NDTensors.contract_labels(labelsA, labelsB)

output_tensor = NDTensors.contraction_output(
NDTensors.tensor(A), labelsA, NDTensors.tensor(B), labelsB, labelsoutput_tensor
)
NC = ndims(output_tensor)
props = NDTensors.ContractionProperties(labelsA, labelsB, labelsoutput_tensor)
NDTensors.compute_contraction_properties!(props, A, B, output_tensor)

dmid = Index(props.dmid)
dleft = Index(props.dleft)
dright = Index(props.dright)

vA = NDTensors.data(storage(A))
vB = NDTensors.data(storage(B))

if props.permuteA
pA = NTuple{NA,Int}(props.PA)
pAinds = map(x -> ind(A, x), pA)
#@timeit_debug timer "_contract!: permutedims A" begin
#Ap = permutedims(NDTensors.tensor(A), pA)
#end # @timeit
AM = setinds!(permute(A, pAinds), (dmid, dleft))
else
#A doesn't have to be permuted
# Don't do the transpose here, do it later in the regular contract function
if NDTensors.Atrans(props)
# println("transposing")
# tranA = transpose(NDTensors.ReshapedArray(NDTensors.data(storage(A)), (props.dmid, props.dleft), ()))
AM = itensor(storage(A), (dmid, dleft))
# @show AM
else
AM = itensor(storage(A), (dleft, dmid))
end
end

if props.permuteB
pB = NTuple{NB,Int}(props.PB)
pBinds = map(x -> ind(B, x), pB)
#@timeit_debug timer "_contract!: permutedims B" begin
#end # @timeit
BM = setinds!(permute(B, pBinds), (dmid, dright))
else
if NDTensors.Btrans(props)
BM = itensor(storage(B), (dright, dmid))
else
BM = itensor(storage(B), (dmid, dright))
end
end

output_tensor = setinds!(_contract(AM, BM), inds(output_tensor))
return output_tensor
end
8 changes: 6 additions & 2 deletions src/tensor_operations/tensor_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function (A::ITensor * B::ITensor)::ITensor
return contract(A, B)
end

function contract(A::ITensor, B::ITensor)::ITensor
function contract(A::ITensor, B::ITensor; matricize=false)::ITensor
NA::Int = ndims(A)
NB::Int = ndims(B)
if NA == 0 && NB == 0
Expand All @@ -101,7 +101,11 @@ function contract(A::ITensor, B::ITensor)::ITensor
elseif NB == 0
return iscombiner(B) ? _contract(B, A) : B[] * A
else
C = using_combine_contract() ? combine_contract(A, B) : _contract(A, B)
if matricize
C = matrix_contract(A, B)
else
C = using_combine_contract() ? combine_contract(A, B) : _contract(A, B)
end
return C
end
end
Expand Down
30 changes: 30 additions & 0 deletions test/base/test_contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,101 +31,118 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
@testset "Test contract ITensor (Scalar*Scalar -> Scalar)" begin
C = A * B
@test scalar(C) ≈ scalar(A) * scalar(B)
@test C == contract(A, B; matricize=true)
end
@testset "Test contract ITensor (Scalar*Vector -> Vector)" begin
C = A * Ai
@test array(C) ≈ scalar(A) * array(Ai)
@test C == contract(A, Ai; matricize=true)
end
@testset "Test contract ITensor (Vector*Scalar -> Vector)" begin
C = Aj * A
@test array(C) ≈ scalar(A) * array(Aj)
@test C == contract(Aj, A; matricize=true)
end
@testset "Test contract ITensors (Vectorᵀ*Vector -> Scalar)" begin
C = Ai * Bi
CArray = transpose(array(Ai)) * array(Bi)
@test CArray ≈ scalar(C)
@test C == contract(Ai, Bi; matricize=true)
end
@testset "Test contract ITensors (Vector*Vectorᵀ -> Matrix)" begin
C = Ai * Aj
for ii in 1:dim(i), jj in 1:dim(j)
@test C[i => ii, j => jj] ≈ Ai[i => ii] * Aj[j => jj]
end
@test C == contract(Ai, Aj; matricize=true)
end
@testset "Test contract ITensors (Matrix*Scalar -> Matrix)" begin
Aij = permute(Aij, i, j)
C = Aij * A
@test array(permute(C, i, j)) ≈ scalar(A) * array(Aij)
@test C == contract(Aij, A; matricize=true)
end
@testset "Test contract ITensors (Matrix*Vector -> Vector)" begin
Aij = permute(Aij, i, j)
C = Aij * Aj
CArray = array(permute(Aij, i, j)) * array(Aj)
@test CArray ≈ array(C)
@test C == contract(Aij, Aj; matricize=true)
end
@testset "Test contract ITensors (Matrixᵀ*Vector -> Vector)" begin
Aij = permute(Aij, j, i)
C = Aij * Aj
CArray = transpose(array(Aij)) * array(Aj)
@test CArray ≈ array(C)
@test C == contract(Aij, Aj; matricize=true)
end
@testset "Test contract ITensors (Vector*Matrix -> Vector)" begin
Aij = permute(Aij, i, j)
C = Ai * Aij
CArray = transpose(transpose(array(Ai)) * array(Aij))
@test CArray ≈ array(C)
@test C == contract(Ai, Aij; matricize=true)
end
@testset "Test contract ITensors (Vector*Matrixᵀ -> Vector)" begin
Aij = permute(Aij, j, i)
C = Ai * Aij
CArray = transpose(transpose(array(Ai)) * transpose(array(Aij)))
@test CArray ≈ array(C)
@test C == contract(Ai, Aij; matricize=true)
end
@testset "Test contract ITensors (Matrix*Matrix -> Scalar)" begin
Aij = permute(Aij, i, j)
Bij = permute(Bij, i, j)
C = Aij * Bij
CArray = LinearAlgebra.tr(array(Aij) * transpose(array(Bij)))
@test CArray ≈ scalar(C)
@test C == contract(Aij, Bij; matricize=true)
end
@testset "Test contract ITensors (Matrix*Matrix -> Matrix)" begin
Aij = permute(Aij, i, j)
Ajk = permute(Ajk, j, k)
C = Aij * Ajk
CArray = array(Aij) * array(Ajk)
@test CArray ≈ array(C)
@test C == contract(Aij, Ajk; matricize=true)
end
@testset "Test contract ITensors (Matrixᵀ*Matrix -> Matrix)" begin
Aij = permute(Aij, j, i)
Ajk = permute(Ajk, j, k)
C = Aij * Ajk
CArray = transpose(array(Aij)) * array(Ajk)
@test CArray ≈ array(C)
@test C == contract(Aij, Ajk; matricize=true)
end
@testset "Test contract ITensors (Matrix*Matrixᵀ -> Matrix)" begin
Aij = permute(Aij, i, j)
Ajk = permute(Ajk, k, j)
C = Aij * Ajk
CArray = array(Aij) * transpose(array(Ajk))
@test CArray ≈ array(C)
@test C == contract(Aij, Ajk; matricize=true)
end
@testset "Test contract ITensors (Matrixᵀ*Matrixᵀ -> Matrix)" begin
Aij = permute(Aij, j, i)
Ajk = permute(Ajk, k, j)
C = Aij * Ajk
CArray = transpose(array(Aij)) * transpose(array(Ajk))
@test CArray ≈ array(C)
@test C == contract(Aij, Ajk; matricize=true)
end
@testset "Test contract ITensors (Matrix⊗Matrix -> 4-tensor)" begin
C = Aij * Akl
for ii in 1:dim(i), jj in 1:dim(j), kk in 1:dim(k), ll in 1:dim(l)
@test C[i => ii, j => jj, k => kk, l => ll] ≈
Aij[i => ii, j => jj] * Akl[k => kk, l => ll]
end
@test C == contract(Aij, Akl; matricize=true)
end
@testset "Test contract ITensors (3-Tensor*Scalar -> 3-Tensor)" begin
Aijk = permute(Aijk, i, j, k)
C = Aijk * A
@test array(permute(C, i, j, k)) ≈ scalar(A) * array(Aijk) rtol = 1e-12
@test C == contract(Aijk, A; matricize=true)
end
@testset "Test contract ITensors (3-Tensor*Vector -> Matrix)" begin
Aijk = permute(Aijk, i, j, k)
Expand All @@ -136,6 +153,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
dim(k),
)
@test CArray ≈ array(permute(C, j, k))
@test C == contract(Aijk, Ai; matricize=true)
end
@testset "Test contract ITensors (Vector*3-Tensor -> Matrix)" begin
Aijk = permute(Aijk, i, j, k)
Expand All @@ -147,6 +165,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
dim(k),
)
@test CArray ≈ array(permute(C, i, k))
@test C == contract(Aj, Aijk; matricize=true)
end
@testset "Test contract ITensors (3-Tensor*Matrix -> Vector)" begin
Aijk = permute(Aijk, i, j, k)
Expand All @@ -155,6 +174,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
CArray =
reshape(array(permute(Aijk, j, i, k)), dim(j), dim(i) * dim(k)) * vec(array(Aik))
@test CArray ≈ array(C)
@test C == contract(Aijk, Aik; matricize=true)
end
@testset "Test contract ITensors (3-Tensor*Matrix -> 3-Tensor)" begin
Aijk = permute(Aijk, i, j, k)
Expand All @@ -167,6 +187,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
dim(l),
)
@test CArray ≈ array(permute(C, i, k, l))
@test C == contract(Aijk, Ajl; matricize=true)
end
@testset "Test contract ITensors (Matrix*3-Tensor -> 3-Tensor)" begin
Aijk = permute(Aijk, i, j, k)
Expand All @@ -180,6 +201,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
dim(j),
)
@test CArray ≈ array(permute(C, l, i, j))
@test C == contract(Akl, Aijk; matricize=true)
end
@testset "Test contract ITensors (3-Tensor*3-Tensor -> 3-Tensor)" begin
Aijk = permute(Aijk, i, j, k)
Expand All @@ -189,6 +211,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
reshape(array(permute(Aijk, i, j, k)), dim(i), dim(j) * dim(k)) *
reshape(array(permute(Ajkl, j, k, l)), dim(j) * dim(k), dim(l))
@test CArray ≈ array(permute(C, i, l))
@test C == contract(Aijk, Ajkl; matricize=true)
end
@testset "Test contract ITensors (3-Tensor*3-Tensor -> 3-Tensor)" begin
for inds_ijk in Combinatorics.permutations([i, j, k]),
Expand All @@ -201,6 +224,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
reshape(array(permute(Ajkl, l, j, k)), dim(l), dim(j) * dim(k)) *
reshape(array(permute(Aijk, j, k, i)), dim(j) * dim(k), dim(i))
@test CArray ≈ array(permute(C, l, i))
@test C == contract(Ajkl, Aijk; matricize=true)
end
end
@testset "Test contract ITensors (4-Tensor*3-Tensor -> 1-Tensor)" begin
Expand All @@ -214,6 +238,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
reshape(array(permute(Ajkl, j, k, l)), 1, dim(j) * dim(k) * dim(l)) *
reshape(array(permute(Aijkl, j, k, l, i)), dim(j) * dim(k) * dim(l), dim(i))
@test vec(CArray) ≈ array(permute(C, i))
@test C == contract(Ajkl, Aijkl; matricize=true)
end
end
@testset "Test contract ITensors (4-Tensor*3-Tensor -> 3-Tensor)" begin
Expand All @@ -231,6 +256,7 @@ digits(::Type{T}, i, j, k) where {T} = T(i * 10^2 + j * 10 + k)
dim(j),
)
@test CArray ≈ array(permute(C, α, i, j))
@test C == contract(Aklα, Aijkl; matricize=true)
end
end
end # End contraction testset
Expand Down Expand Up @@ -265,6 +291,7 @@ end
B = randomITensor(ComplexF64, j, k)
C = A * B
@test array(permute(C, i, k)) ≈ array(A) * array(B)
@test C == contract(A, B; matricize=true)
end
@testset "Complex ITensor * Real ITensor" begin
i = Index(2, "i")
Expand All @@ -274,6 +301,7 @@ end
B = randomITensor(Float64, j, k)
C = A * B
@test array(permute(C, i, k)) ≈ array(A) * array(B)
@test C == contract(A, B; matricize=true)
end

@testset "Outer Product Real ITensor * Complex ITensor" begin
Expand All @@ -283,6 +311,7 @@ end
B = randomITensor(ComplexF64, j)
C = A * B
@test array(permute(C, i, j)) ≈ kron(array(A), transpose(array(B)))
@test C == contract(A, B; matricize=true)
end

@testset "Outer Product: Complex ITensor * Real ITensor" begin
Expand All @@ -292,6 +321,7 @@ end
B = randomITensor(Float64, j)
C = A * B
@test array(permute(C, i, j)) ≈ kron(array(A), transpose(array(B)))
@test C == contract(A, B; matricize=true)
end
end

Expand Down