Skip to content

Commit

Permalink
Do outer tensor-vector products more efficiently
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmul1114 committed Nov 28, 2023
1 parent a933580 commit 060173a
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/gcp-opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,22 @@ function mttkrp(X, U, n)
# Compute tensor-vector products right to left (equations 15, 17) for each rank
for j in 1:r
# Inner tensor-vector products
Rn_j = reshape(reshape(X, Jn, Kn) * reduce(kron, [view(U[i], :, j) for i in reverse(n+1:N)], init=1), size(X)[1:n])

inner = reshape(reshape(X, Jn, Kn) * reduce(kron, [view(U[i], :, j) for i in reverse(n+1:N)], init=1), size(X)[1:n])
# Outer tensor-vector products
Jn_inner = prod(size(inner)[1:n-1])
Kn_inner = prod(size(inner)[n:end])
Rn[:, j] = transpose(reshape(inner, Jn_inner, Kn_inner)) * reduce(kron, [view(U[i], :, j) for i in reverse(1:n-1)],
init=1)
# Permute so dims to be multiplied are last
Rn_j = permutedims(Rn_j, [n; 1:n-1])
#Rn_j = permutedims(Rn_j, [n; 1:n-1])
# Multiply from right to left
sz = size(Rn_j)
m = length(sz)
for k in n-1:-1:1
Rn_j = reshape(Rn_j, prod(sz[1:m-1]), sz[m]) * U[k][:, j]
m -= 1
end
Rn[:, j] = Rn_j
#sz = size(Rn_j)
#m = length(sz)
#for k in n-1:-1:1
# Rn_j = reshape(Rn_j, prod(sz[1:m-1]), sz[m]) * U[k][:, j]
# m -= 1
#end
#Rn[:, j] = Rn_j
end
return Rn
end

0 comments on commit 060173a

Please sign in to comment.