Skip to content

Commit

Permalink
fix inferrability
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Oct 22, 2023
1 parent f27740c commit 0d73994
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/chebyshevtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@ ChebyshevTransformPlan{T,kind}(plan::FFTW.r2rFFTWPlan{T,K,inplace,N,R}) where {T
ChebyshevTransformPlan{T,kind,K,inplace,N,R}(plan)

# jump through some hoops to make inferrable

_fftKtype(::Val{N}, _...) where N = NTuple{N,Int32}
_fftKtype(::Val{N}, ::AbstractVector) where N = Vector{Int32}

function plan_chebyshevtransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
ChebyshevTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
ChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
ChebyshevTransformPlan{T,1}(FFTW.plan_r2r!(x, FIRSTKIND, dims...; kws...))
end
Expand All @@ -34,7 +38,7 @@ end

function plan_chebyshevtransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
ChebyshevTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
ChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
ChebyshevTransformPlan{T,1}(FFTW.plan_r2r(x, FIRSTKIND, dims...; kws...))
end
Expand Down Expand Up @@ -246,7 +250,7 @@ inv(P::IChebyshevTransformPlan{T,1}) where {T} = ChebyshevTransformPlan{T,1}(inv

function plan_ichebyshevtransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
IChebyshevTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
IChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
IChebyshevTransformPlan{T,1}(FFTW.plan_r2r!(x, IFIRSTKIND, dims...; kws...))
end
Expand All @@ -258,7 +262,7 @@ end

function plan_ichebyshevtransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
IChebyshevTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
IChebyshevTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
IChebyshevTransformPlan{T,1}(FFTW.plan_r2r(x, IFIRSTKIND, dims...; kws...))
end
Expand Down Expand Up @@ -306,7 +310,7 @@ function mul!(y::AbstractArray{T,N}, P::IChebyshevTransformPlan{T,1,K,false,N},
size(y) == size(x) || throw(DimensionMismatch("output must match dimension"))
isempty(x) && return y

_icheb1_prescale!(P.plan.region, x) # Todo: don't mutate x
_icheb1_prescale!(P.plan.region, x) # TODO: don't mutate x
_plan_mul!(y, P.plan, x)
_icheb1_postscale!(P.plan.region, x)
ldiv!(2^length(P.plan.region), y)
Expand Down Expand Up @@ -388,7 +392,7 @@ ChebyshevUTransformPlan{T,kind}(plan::FFTW.r2rFFTWPlan{T,K,inplace,N,R}) where {

function plan_chebyshevutransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
ChebyshevUTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
ChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
ChebyshevUTransformPlan{T,1}(FFTW.plan_r2r!(x, UFIRSTKIND, dims...; kws...))
end
Expand All @@ -400,7 +404,7 @@ end

function plan_chebyshevutransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
ChebyshevUTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
ChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
ChebyshevUTransformPlan{T,1}(FFTW.plan_r2r(x, UFIRSTKIND, dims...; kws...))
end
Expand Down Expand Up @@ -466,10 +470,10 @@ end

function mul!(y::AbstractArray{T}, P::ChebyshevUTransformPlan{T,1,K,false}, x::AbstractArray{T}) where {T,K}
size(y) == size(x) || throw(DimensionMismatch("output must match dimension"))
isempty(x) && return copyto!(y, x)
isempty(x) && return y
_chebu1_prescale!(P.plan.region, x) # Todo don't mutate x
_plan_mul!(y, P.plan, x)
_chebu1_postscale!(P.plan.region, y)
_chebu1_postscale!(P.plan.region, x)
for d in P.plan.region
size(y,d) == 1 && ldiv!(2, y) # fix doubling
end
Expand Down Expand Up @@ -535,7 +539,7 @@ end
function mul!(y::AbstractArray{T}, P::ChebyshevUTransformPlan{T,2,K,false}, x::AbstractArray{T}) where {T,K}
n = length(x)
n 1 && return copyto!(y, x)
_chebu2_prescale!(P.plan.region, x)
_chebu2_prescale!(P.plan.region, x) # TODO don't mutate x
_plan_mul!(y, P.plan, x)
_chebu2_postscale!(P.plan.region, x)
lmul!(one(T)/ (n+1), y)
Expand Down Expand Up @@ -571,7 +575,7 @@ IChebyshevUTransformPlan{T,kind}(F::FFTW.r2rFFTWPlan{T,K,inplace,N,R}) where {T,

function plan_ichebyshevutransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
IChebyshevUTransformPlan{T,1,Vector{Int32},true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
IChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),true,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
IChebyshevUTransformPlan{T,1}(FFTW.plan_r2r!(x, IUFIRSTKIND, dims...; kws...))
end
Expand All @@ -583,7 +587,7 @@ end

function plan_ichebyshevutransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
IChebyshevUTransformPlan{T,1,Vector{Int32},false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
IChebyshevUTransformPlan{T,1,_fftKtype(Val{N}(), dims...),false,N,isempty(dims) ? NTuple{N,Int} : typeof(dims[1])}()
else
IChebyshevUTransformPlan{T,1}(FFTW.plan_r2r(x, IUFIRSTKIND, dims...; kws...))
end
Expand Down
1 change: 1 addition & 0 deletions test/chebyshevtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ using FastTransforms, Test
p_1 = chebyshevpoints(T, n)
f = exp.(p_1)
g = @inferred(chebyshevutransform(f))
@test f exp.(p_1)

= x -> [sin((k+1)*acos(x))/sin(acos(x)) for k=0:n-1]' * g
@test (0.1) exp(T(0.1))
Expand Down

0 comments on commit 0d73994

Please sign in to comment.