From 2e2e2ef4847e2184af256fa4fc9e3f7ddc74ed2a Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Sun, 22 Oct 2023 16:52:43 +0100 Subject: [PATCH] matrix ichebyshevu --- src/chebyshevtransform.jl | 68 +++++++++++++++++++++++---------------- test/chebyshevtests.jl | 13 ++++++++ 2 files changed, 53 insertions(+), 28 deletions(-) diff --git a/src/chebyshevtransform.jl b/src/chebyshevtransform.jl index ea0b4cd1..1d237aa9 100644 --- a/src/chebyshevtransform.jl +++ b/src/chebyshevtransform.jl @@ -591,7 +591,7 @@ function plan_ichebyshevutransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kw end function plan_ichebyshevutransform!(x::AbstractArray{T,N}, ::Val{2}, dims...; kws...) where {T<:fftwNumber,N} any(≤(1),size(x)) && throw(ArgumentError("Array must contain at least 2 entries")) - IChebyshevUTransformPlan{T,2}(FFTW.plan_r2r!(x, USECONDKIND)) + IChebyshevUTransformPlan{T,2}(FFTW.plan_r2r!(x, USECONDKIND, dims...)) end function plan_ichebyshevutransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N} @@ -618,42 +618,55 @@ inv(P::IChebyshevUTransformPlan{T,2}) where {T} = ChebyshevUTransformPlan{T,2}(P inv(P::ChebyshevUTransformPlan{T,1}) where {T} = IChebyshevUTransformPlan{T,1}(inv(P.plan).p) inv(P::IChebyshevUTransformPlan{T,1}) where {T} = ChebyshevUTransformPlan{T,1}(inv(P.plan).p) - -function _ichebyu1_postscale!(_, x::AbstractVector{T}) where T - n = length(x) - @inbounds for k=1:n # sqrt(1-x_j^2) weight - x[k] /= 2sinpi(one(T)/(2n) + (k-one(T))/n) +@inline function _ichebu1_postscale!(d::Number, x::AbstractVecOrMat{T}) where T + m,n = size(x,1),size(x,2) + if d == 1 + for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight + x[k,j] /= 2sinpi(one(T)/(2m) + (k-one(T))/m) + end + else + @assert d == 2 + for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight + x[k,j] /= 2sinpi(one(T)/(2n) + (j-one(T))/n) + end end x end -function *(P::IChebyshevUTransformPlan{T,1,K,true}, x::AbstractVector{T}) where {T<:fftwNumber,K} - n = length(x) - n ≤ 1 && return x - x = P.plan * x - _ichebyu1_postscale!(P.plan.region, x) +@inline function _ichebu1_postscale!(d, y::AbstractArray) + for k in d + _ichebu1_postscale!(k, y) + end + y end -function mul!(y::AbstractVector{T}, P::IChebyshevUTransformPlan{T,1,K,false}, x::AbstractVector{T}) where {T<:fftwNumber,K} - n = length(x) - length(y) == n || throw(DimensionMismatch("output must match dimension")) - n ≤ 1 && return x +function *(P::IChebyshevUTransformPlan{T,1,K,true}, x::AbstractArray{T}) where {T<:fftwNumber,K} + length(x) ≤ 1 && return x + x = P.plan * x + _ichebu1_postscale!(P.plan.region, x) +end +function mul!(y::AbstractArray{T}, P::IChebyshevUTransformPlan{T,1,K,false}, x::AbstractArray{T}) where {T<:fftwNumber,K} + size(y) == size(x) || throw(DimensionMismatch("output must match dimension")) + isempty(x) && return y _plan_mul!(y, P.plan, x) - _ichebyu1_postscale!(P.plan.region, y) + _ichebu1_postscale!(P.plan.region, y) end -function _ichebu2_rescale!(_, x::AbstractVector{T}) where T - n = length(x) - c = one(T)/ (n+1) - for k=1:n # sqrt(1-x_j^2) weight - x[k] /= sinpi(k*c) - end +function _ichebu2_rescale!(d::Number, x::AbstractArray{T}) where T + _chebu2_prescale!(d, x) ldiv!(2, x) x end -function *(P::IChebyshevUTransformPlan{T,2,K,true}, x::AbstractVector{T}) where {T<:fftwNumber,K} +@inline function _ichebu2_rescale!(d, y::AbstractArray) + for k in d + _ichebu2_rescale!(k, y) + end + y +end + +function *(P::IChebyshevUTransformPlan{T,2,K,true}, x::AbstractArray{T}) where {T<:fftwNumber,K} n = length(x) n ≤ 1 && return x @@ -661,16 +674,15 @@ function *(P::IChebyshevUTransformPlan{T,2,K,true}, x::AbstractVector{T}) where _ichebu2_rescale!(P.plan.region, x) end -function mul!(y::AbstractVector{T}, P::IChebyshevUTransformPlan{T,2,K,false}, x::AbstractVector{T}) where {T<:fftwNumber,K} - n = length(x) - length(y) == n || throw(DimensionMismatch("output must match dimension")) - n ≤ 1 && return x +function mul!(y::AbstractArray{T}, P::IChebyshevUTransformPlan{T,2,K,false}, x::AbstractArray{T}) where {T<:fftwNumber,K} + size(y) == size(x) || throw(DimensionMismatch("output must match dimension")) + length(x) ≤ 1 && return x _plan_mul!(y, P.plan, x) _ichebu2_rescale!(P.plan.region, y) end -ichebyshevutransform!(x::AbstractVector{T}, dims...; kwds...) where {T<:fftwNumber} = +ichebyshevutransform!(x::AbstractArray{T}, dims...; kwds...) where {T<:fftwNumber} = plan_ichebyshevutransform!(x, dims...; kwds...)*x ichebyshevutransform(x, dims...; kwds...) = plan_ichebyshevutransform(x, dims...; kwds...)*x diff --git a/test/chebyshevtests.jl b/test/chebyshevtests.jl index f6002fe9..48ecc115 100644 --- a/test/chebyshevtests.jl +++ b/test/chebyshevtests.jl @@ -302,6 +302,19 @@ using FastTransforms, Test @test @inferred(chebyshevutransform(X,Val(2))) ≈ @inferred(chebyshevutransform!(copy(X),Val(2))) ≈ chebyshevutransform(chebyshevutransform(X,Val(2),1),Val(2),2) end + @testset "ichebyshevutransform" begin + @test @inferred(ichebyshevutransform(X,1)) ≈ @inferred(ichebyshevutransform!(copy(X),1)) ≈ hcat(ichebyshevutransform.([X[:,k] for k=axes(X,2)])...) + @test ichebyshevutransform(X,2) ≈ ichebyshevutransform!(copy(X),2) ≈ hcat(ichebyshevutransform.([X[k,:] for k=axes(X,1)])...)' + @test @inferred(ichebyshevutransform(X,Val(2),1)) ≈ @inferred(ichebyshevutransform!(copy(X),Val(2),1)) ≈ hcat(ichebyshevutransform.([X[:,k] for k=axes(X,2)],Val(2))...) + @test ichebyshevutransform(X,Val(2),2) ≈ ichebyshevutransform!(copy(X),Val(2),2) ≈ hcat(ichebyshevutransform.([X[k,:] for k=axes(X,1)],Val(2))...)' + + @test @inferred(ichebyshevutransform(X)) ≈ @inferred(ichebyshevutransform!(copy(X))) ≈ ichebyshevutransform(ichebyshevutransform(X,1),2) + @test @inferred(ichebyshevutransform(X,Val(2))) ≈ @inferred(ichebyshevutransform!(copy(X),Val(2))) ≈ ichebyshevutransform(ichebyshevutransform(X,Val(2),1),Val(2),2) + + @test ichebyshevutransform(chebyshevutransform(X)) ≈ X + @test chebyshevutransform(ichebyshevutransform(X)) ≈ X + end + X = randn(1,1) @test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X @test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))