From bde7da74191965a0247d932585ead6254b0813e7 Mon Sep 17 00:00:00 2001 From: Karl Pierce Date: Wed, 25 Oct 2023 11:43:37 -0400 Subject: [PATCH] [NDTensors] Fix CPU performance issue caused by bad mul dispatch (#1218) --- NDTensors/src/abstractarray/iswrappedarray.jl | 1 + NDTensors/src/array/permutedims.jl | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/NDTensors/src/abstractarray/iswrappedarray.jl b/NDTensors/src/abstractarray/iswrappedarray.jl index e1be8987e5..e601e478a2 100644 --- a/NDTensors/src/abstractarray/iswrappedarray.jl +++ b/NDTensors/src/abstractarray/iswrappedarray.jl @@ -33,6 +33,7 @@ parenttype(::Type{<:UnitUpperTriangular{<:Any,P}}) where {P} = P parenttype(::Type{<:UnitLowerTriangular{<:Any,P}}) where {P} = P parenttype(::Type{<:Diagonal{<:Any,P}}) where {P} = P parenttype(::Type{<:SubArray{<:Any,<:Any,P}}) where {P} = P +parenttype(::Type{<:StridedView{<:Any,<:Any,P}}) where {P} = P # For working with instances, not used by # `SimpleTraits.jl` traits dispatch. diff --git a/NDTensors/src/array/permutedims.jl b/NDTensors/src/array/permutedims.jl index 7a033201fb..7a5c52dc14 100644 --- a/NDTensors/src/array/permutedims.jl +++ b/NDTensors/src/array/permutedims.jl @@ -1,6 +1,9 @@ # NDTensors.permutedims function permutedims(::Type{<:Array}, M, perm) - return @strided Base.permutedims(M, perm) + ## Creating Mperm here to evaluate the permutation and + ## avoid returning a Stridedview + @strided Mperm = Base.permutedims(M, perm) + return Mperm end # NDTensors.permutedims!