Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Oct 22, 2023
1 parent 0d73994 commit 48a803b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
27 changes: 18 additions & 9 deletions src/chebyshevtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ ChebyshevTransformPlan{T,kind}(plan::FFTW.r2rFFTWPlan{T,K,inplace,N,R}) where {T

# jump through some hoops to make inferrable

_fftKtype(::Val{N}, _...) where N = NTuple{N,Int32}
_fftKtype(::Val{N}, ::AbstractVector) where N = Vector{Int32}
_fftKtype(::Val{N}, _...) where N = Vector{Int32}

function plan_chebyshevtransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
if isempty(x)
Expand Down Expand Up @@ -410,7 +409,13 @@ function plan_chebyshevutransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws.
end
end
function plan_chebyshevutransform(x::AbstractArray{T,N}, ::Val{2}, dims...; kws...) where {T<:fftwNumber,N}
any((1),size(x)) && throw(ArgumentError("Array must contain at least 2 entries"))
if isempty(dims)
any((1), size(x)) && throw(ArgumentError("Array must contain at least 2 entries"))
else
for d in dims[1]
size(x,d) 1 && throw(ArgumentError("Array must contain at least 2 entries"))
end
end
ChebyshevUTransformPlan{T,2}(FFTW.plan_r2r(x, USECONDKIND, dims...; kws...))
end

Expand Down Expand Up @@ -530,19 +535,23 @@ end
end

function *(P::ChebyshevUTransformPlan{T,2,K,true,N}, x::AbstractArray{T,N}) where {T,K,N}
n = length(x)
n 1 && return x
sc = one(T)
for d in P.plan.region
sc *= one(T)/(size(x,d)+1)
end
_chebu2_prescale!(P.plan.region, x)
lmul!(one(T)/ (n+1), P.plan * x)
lmul!(sc, P.plan * x)
end

function mul!(y::AbstractArray{T}, P::ChebyshevUTransformPlan{T,2,K,false}, x::AbstractArray{T}) where {T,K}
n = length(x)
n 1 && return copyto!(y, x)
sc = one(T)
for d in P.plan.region
sc *= one(T)/(size(x,d)+1)
end
_chebu2_prescale!(P.plan.region, x) # TODO don't mutate x
_plan_mul!(y, P.plan, x)
_chebu2_postscale!(P.plan.region, x)
lmul!(one(T)/ (n+1), y)
lmul!(sc, y)
end

*(P::ChebyshevUTransformPlan{T,kind,K,false,N}, x::AbstractArray{T,N}) where {T,kind,K,N} =
Expand Down
9 changes: 4 additions & 5 deletions test/chebyshevtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,6 @@ using FastTransforms, Test
@test chebyshevtransform(ichebyshevtransform(X)) X
end

X = randn(1,1)
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
@test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))
@test_throws ArgumentError ichebyshevtransform!(copy(X), Val(2))

@testset "chebyshevutransform" begin
@test @inferred(chebyshevutransform(X,1)) @inferred(chebyshevutransform!(copy(X),1)) hcat(chebyshevutransform.([X[:,k] for k=axes(X,2)])...)
@test chebyshevutransform(X,2) chebyshevutransform!(copy(X),2) hcat(chebyshevutransform.([X[k,:] for k=axes(X,1)])...)'
Expand All @@ -307,6 +302,10 @@ using FastTransforms, Test
@test @inferred(chebyshevutransform(X,Val(2))) @inferred(chebyshevutransform!(copy(X),Val(2))) chebyshevutransform(chebyshevutransform(X,Val(2),1),Val(2),2)
end

X = randn(1,1)
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
@test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))
@test_throws ArgumentError ichebyshevtransform!(copy(X), Val(2))
end

@testset "tensor" begin
Expand Down

0 comments on commit 48a803b

Please sign in to comment.