From da3e86580a8db4ca8a3d58a9fdeff6389676efef Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 7 May 2024 12:37:48 +0530 Subject: [PATCH] Use a mutable copy of input if inplace scaling is required (#243) * Use a mutable copy of input if inplace scaling is required * Convert to Array instead of using similar * Add test * Fix type-signature of _plan_mul --- src/chebyshevtransform.jl | 8 +++++--- test/chebyshevtests.jl | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/chebyshevtransform.jl b/src/chebyshevtransform.jl index e98f4ed5..2e58876d 100644 --- a/src/chebyshevtransform.jl +++ b/src/chebyshevtransform.jl @@ -49,8 +49,9 @@ end # convert x if necessary -@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::StridedArray{T}) where T = mul!(y, P, x) -@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, convert(Array{T}, x)) +_maybemutablecopy(x::StridedArray{T}, ::Type{T}) where {T} = x +_maybemutablecopy(x, T) = Array{T}(x) +@inline _plan_mul!(y::AbstractArray{T}, P::Plan{T}, x::AbstractArray) where T = mul!(y, P, _maybemutablecopy(x, T)) for op in (:ldiv, :lmul) @@ -309,7 +310,8 @@ function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,2,K,false,N}, _icheb2_rescale!(P.plan.region, y) end -*(P::IChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} = mul!(similar(x), P, x) +*(P::IChebyshevTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} = + mul!(similar(x), P, _maybemutablecopy(x, T)) ichebyshevtransform!(x::AbstractArray, dims...; kwds...) = plan_ichebyshevtransform!(x, dims...; kwds...)*x ichebyshevtransform(x, dims...; kwds...) = plan_ichebyshevtransform(x, dims...; kwds...)*x diff --git a/test/chebyshevtests.jl b/test/chebyshevtests.jl index 614f9c6d..763ac3ce 100644 --- a/test/chebyshevtests.jl +++ b/test/chebyshevtests.jl @@ -451,6 +451,7 @@ using FastTransforms, Test @testset "immutable vectors" begin F = plan_chebyshevtransform([1.,2,3]) @test chebyshevtransform(1.0:3) == F * (1:3) + @test ichebyshevtransform(1.0:3) == ichebyshevtransform([1.0:3;]) end @testset "inv" begin