Skip to content

Commit

Permalink
matrix ichebyshevu
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Oct 22, 2023
1 parent 48a803b commit 2e2e2ef
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 28 deletions.
68 changes: 40 additions & 28 deletions src/chebyshevtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ function plan_ichebyshevutransform!(x::AbstractArray{T,N}, ::Val{1}, dims...; kw
end
function plan_ichebyshevutransform!(x::AbstractArray{T,N}, ::Val{2}, dims...; kws...) where {T<:fftwNumber,N}
any((1),size(x)) && throw(ArgumentError("Array must contain at least 2 entries"))
IChebyshevUTransformPlan{T,2}(FFTW.plan_r2r!(x, USECONDKIND))
IChebyshevUTransformPlan{T,2}(FFTW.plan_r2r!(x, USECONDKIND, dims...))
end

function plan_ichebyshevutransform(x::AbstractArray{T,N}, ::Val{1}, dims...; kws...) where {T<:fftwNumber,N}
Expand All @@ -618,59 +618,71 @@ inv(P::IChebyshevUTransformPlan{T,2}) where {T} = ChebyshevUTransformPlan{T,2}(P
inv(P::ChebyshevUTransformPlan{T,1}) where {T} = IChebyshevUTransformPlan{T,1}(inv(P.plan).p)
inv(P::IChebyshevUTransformPlan{T,1}) where {T} = ChebyshevUTransformPlan{T,1}(inv(P.plan).p)


function _ichebyu1_postscale!(_, x::AbstractVector{T}) where T
n = length(x)
@inbounds for k=1:n # sqrt(1-x_j^2) weight
x[k] /= 2sinpi(one(T)/(2n) + (k-one(T))/n)
@inline function _ichebu1_postscale!(d::Number, x::AbstractVecOrMat{T}) where T
m,n = size(x,1),size(x,2)
if d == 1
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= 2sinpi(one(T)/(2m) + (k-one(T))/m)
end
else
@assert d == 2
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= 2sinpi(one(T)/(2n) + (j-one(T))/n)
end
end
x
end
function *(P::IChebyshevUTransformPlan{T,1,K,true}, x::AbstractVector{T}) where {T<:fftwNumber,K}
n = length(x)
n 1 && return x

x = P.plan * x
_ichebyu1_postscale!(P.plan.region, x)
@inline function _ichebu1_postscale!(d, y::AbstractArray)
for k in d
_ichebu1_postscale!(k, y)
end
y
end

function mul!(y::AbstractVector{T}, P::IChebyshevUTransformPlan{T,1,K,false}, x::AbstractVector{T}) where {T<:fftwNumber,K}
n = length(x)
length(y) == n || throw(DimensionMismatch("output must match dimension"))
n 1 && return x
function *(P::IChebyshevUTransformPlan{T,1,K,true}, x::AbstractArray{T}) where {T<:fftwNumber,K}
length(x) 1 && return x
x = P.plan * x
_ichebu1_postscale!(P.plan.region, x)
end

function mul!(y::AbstractArray{T}, P::IChebyshevUTransformPlan{T,1,K,false}, x::AbstractArray{T}) where {T<:fftwNumber,K}
size(y) == size(x) || throw(DimensionMismatch("output must match dimension"))
isempty(x) && return y
_plan_mul!(y, P.plan, x)
_ichebyu1_postscale!(P.plan.region, y)
_ichebu1_postscale!(P.plan.region, y)
end

function _ichebu2_rescale!(_, x::AbstractVector{T}) where T
n = length(x)
c = one(T)/ (n+1)
for k=1:n # sqrt(1-x_j^2) weight
x[k] /= sinpi(k*c)
end
function _ichebu2_rescale!(d::Number, x::AbstractArray{T}) where T
_chebu2_prescale!(d, x)
ldiv!(2, x)
x
end

function *(P::IChebyshevUTransformPlan{T,2,K,true}, x::AbstractVector{T}) where {T<:fftwNumber,K}
@inline function _ichebu2_rescale!(d, y::AbstractArray)
for k in d
_ichebu2_rescale!(k, y)
end
y
end

function *(P::IChebyshevUTransformPlan{T,2,K,true}, x::AbstractArray{T}) where {T<:fftwNumber,K}
n = length(x)
n 1 && return x

x = P.plan * x
_ichebu2_rescale!(P.plan.region, x)
end

function mul!(y::AbstractVector{T}, P::IChebyshevUTransformPlan{T,2,K,false}, x::AbstractVector{T}) where {T<:fftwNumber,K}
n = length(x)
length(y) == n || throw(DimensionMismatch("output must match dimension"))
n 1 && return x
function mul!(y::AbstractArray{T}, P::IChebyshevUTransformPlan{T,2,K,false}, x::AbstractArray{T}) where {T<:fftwNumber,K}
size(y) == size(x) || throw(DimensionMismatch("output must match dimension"))
length(x) 1 && return x

_plan_mul!(y, P.plan, x)
_ichebu2_rescale!(P.plan.region, y)
end

ichebyshevutransform!(x::AbstractVector{T}, dims...; kwds...) where {T<:fftwNumber} =
ichebyshevutransform!(x::AbstractArray{T}, dims...; kwds...) where {T<:fftwNumber} =
plan_ichebyshevutransform!(x, dims...; kwds...)*x

ichebyshevutransform(x, dims...; kwds...) = plan_ichebyshevutransform(x, dims...; kwds...)*x
Expand Down
13 changes: 13 additions & 0 deletions test/chebyshevtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,19 @@ using FastTransforms, Test
@test @inferred(chebyshevutransform(X,Val(2))) @inferred(chebyshevutransform!(copy(X),Val(2))) chebyshevutransform(chebyshevutransform(X,Val(2),1),Val(2),2)
end

@testset "ichebyshevutransform" begin
@test @inferred(ichebyshevutransform(X,1)) @inferred(ichebyshevutransform!(copy(X),1)) hcat(ichebyshevutransform.([X[:,k] for k=axes(X,2)])...)
@test ichebyshevutransform(X,2) ichebyshevutransform!(copy(X),2) hcat(ichebyshevutransform.([X[k,:] for k=axes(X,1)])...)'
@test @inferred(ichebyshevutransform(X,Val(2),1)) @inferred(ichebyshevutransform!(copy(X),Val(2),1)) hcat(ichebyshevutransform.([X[:,k] for k=axes(X,2)],Val(2))...)
@test ichebyshevutransform(X,Val(2),2) ichebyshevutransform!(copy(X),Val(2),2) hcat(ichebyshevutransform.([X[k,:] for k=axes(X,1)],Val(2))...)'

@test @inferred(ichebyshevutransform(X)) @inferred(ichebyshevutransform!(copy(X))) ichebyshevutransform(ichebyshevutransform(X,1),2)
@test @inferred(ichebyshevutransform(X,Val(2))) @inferred(ichebyshevutransform!(copy(X),Val(2))) ichebyshevutransform(ichebyshevutransform(X,Val(2),1),Val(2),2)

@test ichebyshevutransform(chebyshevutransform(X)) X
@test chebyshevutransform(ichebyshevutransform(X)) X
end

X = randn(1,1)
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
@test_throws ArgumentError chebyshevtransform!(copy(X), Val(2))
Expand Down

0 comments on commit 2e2e2ef

Please sign in to comment.