Skip to content

Commit

Permalink
add Chebyshev transform
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Jul 11, 2022
1 parent ae5c22e commit d49f1ac
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/Transform/Transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ export
abstract type AbstractTransform end

include("fourier_transform.jl")
include("chebyshev_transform.jl")

const truncate_modes = low_pass
23 changes: 23 additions & 0 deletions src/Transform/chebyshev_transform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
export ChebyshevTransform

struct ChebyshevTransform{N, S}<:AbstractTransform
modes::NTuple{N, S} # N == ndims(x)
end

Base.ndims(::ChebyshevTransform{N}) where {N} = N

function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
return FFTW.r2r(𝐱, FFTW.REDFT00, 1:N) # [size(x)..., in_chs, batch]
end

function low_pass(t::ChebyshevTransform, 𝐱̂::AbstractArray)
return view(𝐱̂, map(d->1:d, t.modes)..., :, :) # [ft.modes..., in_chs, batch]
end

function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
return FFTW.r2r(
𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))),
FFTW.REDFT00,
1:N,
) # [size(x)..., in_chs, batch]
end
5 changes: 4 additions & 1 deletion test/Transform/Transform.jl
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
@testset "Transform" begin include("fourier_transform.jl") end
@testset "Transform" begin
include("fourier_transform.jl")
include("chebyshev_transform.jl")
end
11 changes: 11 additions & 0 deletions test/Transform/chebyshev_transform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@testset "Chebyshev transform" begin
ch = 6
batch = 7
𝐱 = rand(30, 40, 50, ch, batch)

t = ChebyshevTransform((3, 4, 5))

@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
end
12 changes: 7 additions & 5 deletions test/Transform/fourier_transform.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
@testset "fourier transform" begin
𝐱 = rand(30, 40, 50, 6, 7) # where ch == 6 and batch == 7
@testset "Fourier transform" begin
ch = 6
batch = 7
𝐱 = rand(30, 40, 50, ch, batch)

ft = FourierTransform((3, 4, 5))

@test size(transform(ft, 𝐱)) == (30, 40, 50, 6, 7)
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, 6, 7)
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, 6, 7)
@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch)
end

0 comments on commit d49f1ac

Please sign in to comment.