From 0d739948f150cb4a4469667a032db03259643d71 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Sun, 22 Oct 2023 14:43:45 +0100 Subject: [PATCH] fix inferrability --- src/chebyshevtransform.jl | 28 ++++++++++++++++------------ test/chebyshevtests.jl | 1 + 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/chebyshevtransform.jl b/src/chebyshevtransform.jl index 44500ad4..62649ce4 100644 --- a/src/chebyshevtransform.jl +++ b/src/chebyshevtransform.jl @@ -19,9 +19,13 @@ ChebyshevTransformPlan{T,kind}(plan::FFTW.r2rFFTWPlan{T,K,inplace,N,R}) where {T ChebyshevTransformPlan{T,kind,K,inplace,N,R}(plan) # jump through some hoops to make inferrable + +_fftKtype(::Val{N}, _...) where N = NTuple{N,Int32} +_fftKtype(::Val{N}, ::AbstractVector) where N = Vector{Int32} + function plan_chebyshevtransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - ChebyshevTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + ChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else ChebyshevTransformPlan{T,1}(FFTW.plan_r2r!(x, FIRSTKIND, dims...; kws...)) end @@ -34,7 +38,7 @@ end function plan_chebyshevtransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - ChebyshevTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + ChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else ChebyshevTransformPlan{T,1}(FFTW.plan_r2r(x, FIRSTKIND, dims...; kws...)) end @@ -246,7 +250,7 @@ inv(P::IChebyshevTransformPlan{T,1}) where {T} = ChebyshevTransformPlan{T,1}(inv function plan_ichebyshevtransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - IChebyshevTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + IChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else IChebyshevTransformPlan{T,1}(FFTW.plan_r2r!(x, IFIRSTKIND, dims...; kws...)) end @@ -258,7 +262,7 @@ end function plan_ichebyshevtransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - IChebyshevTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + IChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else IChebyshevTransformPlan{T,1}(FFTW.plan_r2r(x, IFIRSTKIND, dims...; kws...)) end @@ -306,7 +310,7 @@ function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,1,K,false,N}, size(y) == size(x) || throw(DimensionMismatch("output must match dimension")) isempty(x) && return y - _icheb1_prescale!(P.plan.region, x) # Todo: don't mutate x + _icheb1_prescale!(P.plan.region, x) # TODO: don't mutate x _plan_mul!(y, P.plan, x) _icheb1_postscale!(P.plan.region, x) ldiv!(2^length(P.plan.region), y) @@ -388,7 +392,7 @@ ChebyshevUTransformPlan{T,kind}(plan::FFTW.r2rFFTWPlan{T,K,inplace,N,R}) where { function plan_chebyshevutransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - ChebyshevUTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + ChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else ChebyshevUTransformPlan{T,1}(FFTW.plan_r2r!(x, UFIRSTKIND, dims...; kws...)) end @@ -400,7 +404,7 @@ end function plan_chebyshevutransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - ChebyshevUTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + ChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else ChebyshevUTransformPlan{T,1}(FFTW.plan_r2r(x, UFIRSTKIND, dims...; kws...)) end @@ -466,10 +470,10 @@ end function mul!(y::AbstractArray{T}, P::ChebyshevUTransformPlan{T,1,K,false}, x::AbstractArray{T}) where {T,K} size(y) == size(x) || throw(DimensionMismatch("output must match dimension")) - isempty(x) && return copyto!(y, x) + isempty(x) && return y _chebu1_prescale!(P.plan.region, x) # Todo don't mutate x _plan_mul!(y, P.plan, x) - _chebu1_postscale!(P.plan.region, y) + _chebu1_postscale!(P.plan.region, x) for d in P.plan.region size(y,d) == 1 && ldiv!(2, y) # fix doubling end @@ -535,7 +539,7 @@ end function mul!(y::AbstractArray{T}, P::ChebyshevUTransformPlan{T,2,K,false}, x::AbstractArray{T}) where {T,K} n = length(x) n ≤ 1 && return copyto!(y, x) - _chebu2_prescale!(P.plan.region, x) + _chebu2_prescale!(P.plan.region, x) # TODO don't mutate x _plan_mul!(y, P.plan, x) _chebu2_postscale!(P.plan.region, x) lmul!(one(T)/ (n+1), y) @@ -571,7 +575,7 @@ IChebyshevUTransformPlan{T,kind}(F::FFTW.r2rFFTWPlan{T,K,inplace,N,R}) where {T, function plan_ichebyshevutransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - IChebyshevUTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + IChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else IChebyshevUTransformPlan{T,1}(FFTW.plan_r2r!(x, IUFIRSTKIND, dims...; kws...)) end @@ -583,7 +587,7 @@ end function plan_ichebyshevutransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} if isempty(x) - IChebyshevUTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() + IChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}() else IChebyshevUTransformPlan{T,1}(FFTW.plan_r2r(x, IUFIRSTKIND, dims...; kws...)) end diff --git a/test/chebyshevtests.jl b/test/chebyshevtests.jl index 3eb7fa1c..27546a2f 100644 --- a/test/chebyshevtests.jl +++ b/test/chebyshevtests.jl @@ -154,6 +154,7 @@ using FastTransforms, Test p_1 = chebyshevpoints(T, n) f = exp.(p_1) g = @inferred(chebyshevutransform(f)) + @test f ≈ exp.(p_1) f̃ = x -> [sin((k+1)*acos(x))/sin(acos(x)) for k=0:n-1]' * g @test f̃(0.1) ≈ exp(T(0.1))