Skip to content

Commit

Permalink
update permutedim calls
Browse files Browse the repository at this point in the history
  • Loading branch information
kmp5VT committed Oct 11, 2023
1 parent 86bd294 commit 0dd361c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion NDTensors/src/arraytensor/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ end
function permutedims!(
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function
)
output_array .= f.(output_array, permutedims!!(leaf_parenttype(array), array, perm))
output_array = permutedims!!(leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f)
return output_array
end
5 changes: 2 additions & 3 deletions NDTensors/src/dense/densetensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ function permutedims!(
) where {N,StoreT<:StridedArray}
RA = array(R)
TA = array(T)
RA .= permutedims!!(leaf_parenttype(TA), TA, perm)
RA .= permutedims!(leaf_parenttype(RA), RA, leaf_parenttype(TA), TA, perm)
return R
end

Expand Down Expand Up @@ -247,8 +247,7 @@ function permutedims!(
end
RA = array(R)
TA = array(T)
RA .= f.(RA, permutedims!!(leaf_parenttype(TA), TA, perm))
return R
return permutedims!!(leaf_parenttype(RA), RA, leaf_parenttype(TA), TA, perm, f)
end

"""
Expand Down
12 changes: 6 additions & 6 deletions NDTensors/src/dense/tensoralgebra/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,14 @@ function _contract_scalar_perm!(
if iszero(α)
fill!(Rᵃ, 0)
else
Rᵃ .= α .* permutedims!!(leaf_parenttype(Tᵃ), Tᵃ, perm)
Rᵃ = permutedims!!(leaf_parenttype(Rᵃ), Rᵃ, leaf_parenttype(Tᵃ), Tᵃ, perm, (r,t) -> *(α, t))
end
elseif isone(β)
if iszero(α)
# Rᵃ .= Rᵃ
# No-op
else
Rᵃ .= α .* permutedims!!(leaf_parenttype(Tᵃ), Tᵃ, perm) .+ Rᵃ
Rᵃ = permutedims!!(leaf_parenttype(Rᵃ), Rᵃ, leaf_parenttype(Tᵃ), Tᵃ, perm, (r,t) -> r + α * t)
end
else
if iszero(α)
Expand Down Expand Up @@ -346,7 +346,7 @@ function _contract!(
t = TimerOutput()
if props.permuteA
#@timeit_debug timer "_contract!: permutedims A" begin
Ap = permutedims!!(leaf_parenttype(AT), AT, props.PA)
Ap = permutedims(leaf_parenttype(AT), AT, props.PA)
#end # @timeit
AM = transpose(reshape(Ap, (props.dmid, props.dleft)))
else
Expand All @@ -361,7 +361,7 @@ function _contract!(
tB = 'N'
if props.permuteB
#@timeit_debug timer "_contract!: permutedims B" begin
Bp = permutedims!!(leaf_parenttype(BT), BT, props.PB)
Bp = permutedims(leaf_parenttype(BT), BT, props.PB)
#end # @timeit
BM = reshape(Bp, (props.dmid, props.dright))
else
Expand All @@ -379,7 +379,7 @@ function _contract!(
# ordering as A B which is the inverse of props.PC
if β 0
CM = reshape(
permutedims!!(leaf_parenttype(CT), CT, invperm(props.PC)),
permutedims(leaf_parenttype(CT), CT, invperm(props.PC)),
(props.dleft, props.dright),
)
else
Expand All @@ -405,7 +405,7 @@ function _contract!(
Cr = reshape(CM, props.newCrange)
# TODO: use invperm(pC) here?
#@timeit_debug timer "_contract!: permutedims C" begin
CT .= permutedims!!(leaf_parenttype(Cr), Cr, props.PC)
CT .= permutedims(leaf_parenttype(Cr), Cr, props.PC)
#end # @timeit
end

Expand Down

0 comments on commit 0dd361c

Please sign in to comment.