Skip to content

Commit

Permalink
Merge pull request #87 from tknopp/rfft
Browse files Browse the repository at this point in the history
Change FFT to Real to Complex FFT
  • Loading branch information
yuehhua authored Nov 29, 2022
2 parents 170ae0e + fed68b8 commit 3a2d04f
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/Transform/chebyshev_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
end

function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T, N},
M::NTuple{N, Int64}) where {T, N}
normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1)))
return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch]
end
Expand Down
7 changes: 4 additions & 3 deletions src/Transform/fourier_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Base.ndims(::FourierTransform{N}) where {N} = N
Base.eltype(::Type{FourierTransform}) = ComplexF32

function transform(ft::FourierTransform, 𝐱::AbstractArray)
return fft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch]
return rfft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch]
end

function low_pass(ft::FourierTransform, 𝐱_fft::AbstractArray)
Expand All @@ -17,6 +17,7 @@ end

truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft)

function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray)
return real(ifft(𝐱_fft, 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T, N},
M::NTuple{N, Int64}) where {T, N}
return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
end
2 changes: 1 addition & 1 deletion src/operator_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function operator_conv(m::OperatorConv, 𝐱::AbstractArray)
𝐱_padded = pad_modes(𝐱_applied_pattern,
(size(𝐱_transformed)[1:(end - 2)]...,
size(𝐱_applied_pattern)[(end - 1):end]...)) # [size(x)..., out_chs, batch] <- [modes..., out_chs, batch]
𝐱_inversed = inverse(m.transform, 𝐱_padded)
𝐱_inversed = inverse(m.transform, 𝐱_padded, size(𝐱))

return 𝐱_inversed
end
Expand Down
5 changes: 3 additions & 2 deletions test/Transform/chebyshev_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
@test ndims(t) == 3
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) ==
(3, 4, 5, ch, batch)

g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱)
@test size(g[1]) == (30, 40, 50, ch, batch)
end
14 changes: 11 additions & 3 deletions test/Transform/fourier_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@

ft = FourierTransform((3, 4, 5))

@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
@test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch)
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch)
@test size(inverse(ft,
NeuralOperators.pad_modes(truncate_modes(ft, transform(ft, 𝐱)),
size(transform(ft, 𝐱))),
size(𝐱))) == (30, 40, 50, ch, batch)

g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)))), 𝐱)
g = Zygote.gradient(x -> sum(inverse(ft,
NeuralOperators.pad_modes(truncate_modes(ft,
transform(ft,
x)),
(16, 40, 50, ch, batch)),
(30, 40, 50, ch, batch))), 𝐱)
@test size(g[1]) == (30, 40, 50, ch, batch)
end
8 changes: 4 additions & 4 deletions test/operator_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ end
end

@testset "2D OperatorConv" begin
modes = (16, 16)
modes = (10, 10)
ch = 64 => 64

m = Chain(Dense(1, 64),
Expand All @@ -87,7 +87,7 @@ end
end

@testset "permuted 2D OperatorConv" begin
modes = (16, 16)
modes = (10, 10)
ch = 64 => 64

m = Chain(Conv((1, 1), 1 => 64),
Expand All @@ -104,7 +104,7 @@ end
end

@testset "2D OperatorKernel" begin
modes = (16, 16)
modes = (10, 10)
ch = 64 => 64

m = Chain(Dense(1, 64),
Expand All @@ -119,7 +119,7 @@ end
end

@testset "permuted 2D OperatorKernel" begin
modes = (16, 16)
modes = (10, 10)
ch = 64 => 64

m = Chain(Conv((1, 1), 1 => 64),
Expand Down

0 comments on commit 3a2d04f

Please sign in to comment.