From 0dd361c5eea135eed3405042aa799b4812e00b84 Mon Sep 17 00:00:00 2001 From: kmp5VT Date: Wed, 11 Oct 2023 15:42:24 -0400 Subject: [PATCH] update permutedim calls --- NDTensors/src/arraytensor/array.jl | 2 +- NDTensors/src/dense/densetensor.jl | 5 ++--- NDTensors/src/dense/tensoralgebra/contract.jl | 12 ++++++------ 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/NDTensors/src/arraytensor/array.jl b/NDTensors/src/arraytensor/array.jl index c2950b2c20..8abe19d74c 100644 --- a/NDTensors/src/arraytensor/array.jl +++ b/NDTensors/src/arraytensor/array.jl @@ -61,6 +61,6 @@ end function permutedims!( output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function ) - output_array .= f.(output_array, permutedims!!(leaf_parenttype(array), array, perm)) + output_array = permutedims!!(leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f) return output_array end diff --git a/NDTensors/src/dense/densetensor.jl b/NDTensors/src/dense/densetensor.jl index 0743ca4ae0..d3d5582fcb 100644 --- a/NDTensors/src/dense/densetensor.jl +++ b/NDTensors/src/dense/densetensor.jl @@ -199,7 +199,7 @@ function permutedims!( ) where {N,StoreT<:StridedArray} RA = array(R) TA = array(T) - RA .= permutedims!!(leaf_parenttype(TA), TA, perm) + RA .= permutedims!(leaf_parenttype(RA), RA, leaf_parenttype(TA), TA, perm) return R end @@ -247,8 +247,7 @@ function permutedims!( end RA = array(R) TA = array(T) - RA .= f.(RA, permutedims!!(leaf_parenttype(TA), TA, perm)) - return R + return permutedims!!(leaf_parenttype(RA), RA, leaf_parenttype(TA), TA, perm, f) end """ diff --git a/NDTensors/src/dense/tensoralgebra/contract.jl b/NDTensors/src/dense/tensoralgebra/contract.jl index 8b764c1913..6b1ceef1e8 100644 --- a/NDTensors/src/dense/tensoralgebra/contract.jl +++ b/NDTensors/src/dense/tensoralgebra/contract.jl @@ -175,14 +175,14 @@ function _contract_scalar_perm!( if iszero(α) fill!(Rᵃ, 0) else - Rᵃ .= α .* permutedims!!(leaf_parenttype(Tᵃ), Tᵃ, perm) + Rᵃ = permutedims!!(leaf_parenttype(Rᵃ), Rᵃ, leaf_parenttype(Tᵃ), Tᵃ, perm, (r,t) -> *(α, t)) end elseif isone(β) if iszero(α) # Rᵃ .= Rᵃ # No-op else - Rᵃ .= α .* permutedims!!(leaf_parenttype(Tᵃ), Tᵃ, perm) .+ Rᵃ + Rᵃ = permutedims!!(leaf_parenttype(Rᵃ), Rᵃ, leaf_parenttype(Tᵃ), Tᵃ, perm, (r,t) -> r + α * t) end else if iszero(α) @@ -346,7 +346,7 @@ function _contract!( t = TimerOutput() if props.permuteA #@timeit_debug timer "_contract!: permutedims A" begin - Ap = permutedims!!(leaf_parenttype(AT), AT, props.PA) + Ap = permutedims(leaf_parenttype(AT), AT, props.PA) #end # @timeit AM = transpose(reshape(Ap, (props.dmid, props.dleft))) else @@ -361,7 +361,7 @@ function _contract!( tB = 'N' if props.permuteB #@timeit_debug timer "_contract!: permutedims B" begin - Bp = permutedims!!(leaf_parenttype(BT), BT, props.PB) + Bp = permutedims(leaf_parenttype(BT), BT, props.PB) #end # @timeit BM = reshape(Bp, (props.dmid, props.dright)) else @@ -379,7 +379,7 @@ function _contract!( # ordering as A B which is the inverse of props.PC if β ≠ 0 CM = reshape( - permutedims!!(leaf_parenttype(CT), CT, invperm(props.PC)), + permutedims(leaf_parenttype(CT), CT, invperm(props.PC)), (props.dleft, props.dright), ) else @@ -405,7 +405,7 @@ function _contract!( Cr = reshape(CM, props.newCrange) # TODO: use invperm(pC) here? #@timeit_debug timer "_contract!: permutedims C" begin - CT .= permutedims!!(leaf_parenttype(Cr), Cr, props.PC) + CT .= permutedims(leaf_parenttype(Cr), Cr, props.PC) #end # @timeit end