Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Nov 10, 2023
1 parent c88516d commit 097d881
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 58 deletions.
80 changes: 24 additions & 56 deletions src/chebyshevtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,19 +439,11 @@ end
y
end

@inline function _chebu1_postscale!(d::Number, x::AbstractVecOrMat{T}) where T
m,n = size(x,1),size(x,2)
if d == 1
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= sinpi(one(T)/(2m) + (k-one(T))/m)/m
end
else
@assert d == 2
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= sinpi(one(T)/(2n) + (j-one(T))/n)/n
end
end
x
@inline function _chebu1_postscale!(d::Number, X::AbstractArray{T,N}) where {T,N}
= PermutedDimsArray(X, _permfirst(d, N))
m = size(X̃,1)
X̃ .=./ (sinpi.(one(T)/(2m) .+ ((1:m) .- one(T))/m) ./ m)
X
end

@inline function _chebu1_postscale!(d, y::AbstractArray)
Expand Down Expand Up @@ -479,21 +471,13 @@ function mul!(y::AbstractArray{T}, P::ChebyshevUTransformPlan{T,1,K,false}, x::A
y
end

@inline function _chebu2_prescale!(d::Number, x::AbstractVecOrMat{T}) where T
m,n = size(x,1),size(x,2)
if d == 1
c = one(T)/ (m+1)
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] *= sinpi(k*c)
end
else
@assert d == 2
c = one(T)/ (n+1)
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] *= sinpi(j*c)
end
end
x

@inline function _chebu2_prescale!(d::Number, X::AbstractArray{T,N}) where {T,N}
= PermutedDimsArray(X, _permfirst(d, N))
m = size(X̃,1)
c = one(T)/ (m+1)
X̃ .= sinpi.((1:m) .* c) .*
X
end

@inline function _chebu2_prescale!(d, y::AbstractArray)
Expand All @@ -504,21 +488,12 @@ end
end


@inline function _chebu2_postscale!(d::Number, x::AbstractVecOrMat{T}) where T
m,n = size(x,1),size(x,2)
if d == 1
c = one(T)/ (m+1)
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= sinpi(k*c)
end
else
@assert d == 2
c = one(T)/ (n+1)
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= sinpi(j*c)
end
end
x
@inline function _chebu2_postscale!(d::Number, X::AbstractArray{T,N}) where {T,N}
= PermutedDimsArray(X, _permfirst(d, N))
m = size(X̃,1)
c = one(T)/ (m+1)
X̃ .=./ sinpi.((1:m) .* c)
X
end

@inline function _chebu2_postscale!(d, y::AbstractArray)
Expand Down Expand Up @@ -612,21 +587,14 @@ inv(P::IChebyshevUTransformPlan{T,2}) where {T} = ChebyshevUTransformPlan{T,2}(P
inv(P::ChebyshevUTransformPlan{T,1}) where {T} = IChebyshevUTransformPlan{T,1}(inv(P.plan).p)
inv(P::IChebyshevUTransformPlan{T,1}) where {T} = ChebyshevUTransformPlan{T,1}(inv(P.plan).p)

@inline function _ichebu1_postscale!(d::Number, x::AbstractVecOrMat{T}) where T
m,n = size(x,1),size(x,2)
if d == 1
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= 2sinpi(one(T)/(2m) + (k-one(T))/m)
end
else
@assert d == 2
for j = 1:n, k = 1:m # sqrt(1-x_j^2) weight
x[k,j] /= 2sinpi(one(T)/(2n) + (j-one(T))/n)
end
end
x
@inline function _ichebu1_postscale!(d::Number, X::AbstractArray{T,N}) where {T,N}
= PermutedDimsArray(X, _permfirst(d, N))
m = size(X̃,1)
X̃ .=./ (2 .* sinpi.(one(T)/(2m) .+ ((1:m) .- one(T))/m))
X
end


@inline function _ichebu1_postscale!(d, y::AbstractArray)
for k in d
_ichebu1_postscale!(k, y)
Expand Down
45 changes: 43 additions & 2 deletions test/chebyshevtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ using FastTransforms, Test
gcopy = copy(g)
P = @inferred(plan_chebyshevutransform(f))
@test P*f g
@test f == fcopy
@test f fcopy
@test_throws ArgumentError P * T[1,2]
P = @inferred(plan_chebyshevutransform(f, 1:1))
@test P*f g
@test f == fcopy
@test f fcopy
@test_throws ArgumentError P * T[1,2]

P = @inferred(plan_chebyshevutransform!(f))
Expand Down Expand Up @@ -364,6 +364,47 @@ using FastTransforms, Test
@test ichebyshevtransform(chebyshevtransform(X)) X
@test chebyshevtransform(ichebyshevtransform(X)) X
end

@testset "chebyshevutransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j]) end
@test @inferred(chebyshevutransform(X,1)) @inferred(chebyshevutransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j]) end
@test chebyshevutransform(X,2) chebyshevutransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:]) end
@test chebyshevutransform(X,3) chebyshevutransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = chebyshevutransform(X[:,k,j],Val(2)) end
@test @inferred(chebyshevutransform(X,Val(2),1)) @inferred(chebyshevutransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = chebyshevutransform(X[k,:,j],Val(2)) end
@test chebyshevutransform(X,Val(2),2) chebyshevutransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = chebyshevutransform(X[k,j,:],Val(2)) end
@test chebyshevutransform(X,Val(2),3) chebyshevutransform!(copy(X),Val(2),3)

@test @inferred(chebyshevutransform(X)) @inferred(chebyshevutransform!(copy(X))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,1),2),3)
@test @inferred(chebyshevutransform(X,Val(2))) @inferred(chebyshevutransform!(copy(X),Val(2))) chebyshevutransform(chebyshevutransform(chebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)
end

@testset "ichebyshevutransform" begin
for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j]) end
@test @inferred(ichebyshevutransform(X,1)) @inferred(ichebyshevutransform!(copy(X),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j]) end
@test ichebyshevutransform(X,2) ichebyshevutransform!(copy(X),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:]) end
@test ichebyshevutransform(X,3) ichebyshevutransform!(copy(X),3)

for k = axes(X,2), j = axes(X,3) X̃[:,k,j] = ichebyshevutransform(X[:,k,j],Val(2)) end
@test @inferred(ichebyshevutransform(X,Val(2),1)) @inferred(ichebyshevutransform!(copy(X),Val(2),1))
for k = axes(X,1), j = axes(X,3) X̃[k,:,j] = ichebyshevutransform(X[k,:,j],Val(2)) end
@test ichebyshevutransform(X,Val(2),2) ichebyshevutransform!(copy(X),Val(2),2)
for k = axes(X,1), j = axes(X,2) X̃[k,j,:] = ichebyshevutransform(X[k,j,:],Val(2)) end
@test ichebyshevutransform(X,Val(2),3) ichebyshevutransform!(copy(X),Val(2),3)

@test @inferred(ichebyshevutransform(X)) @inferred(ichebyshevutransform!(copy(X))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,1),2),3)
@test @inferred(ichebyshevutransform(X,Val(2))) @inferred(ichebyshevutransform!(copy(X),Val(2))) ichebyshevutransform(ichebyshevutransform(ichebyshevutransform(X,Val(2),1),Val(2),2),Val(2),3)

@test ichebyshevutransform(chebyshevutransform(X)) X
@test chebyshevutransform(ichebyshevutransform(X)) X
end

X = randn(1,1,1)
@test chebyshevtransform!(copy(X), Val(1)) == ichebyshevtransform!(copy(X), Val(1)) == X
Expand Down

0 comments on commit 097d881

Please sign in to comment.