From a4e887f6740b147bcf3a8f2149bbc233cf5f4867 Mon Sep 17 00:00:00 2001 From: Tobias Knopp Date: Mon, 22 Aug 2022 20:19:50 +0200 Subject: [PATCH 1/5] use rfft instead of fft --- src/Transform/chebyshev_transform.jl | 2 +- src/Transform/fourier_transform.jl | 6 +++--- src/operator_kernel.jl | 2 +- test/Transform/chebyshev_transform.jl | 4 ++-- test/Transform/fourier_transform.jl | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Transform/chebyshev_transform.jl b/src/Transform/chebyshev_transform.jl index 4ec783d7..6e7d6c6a 100644 --- a/src/Transform/chebyshev_transform.jl +++ b/src/Transform/chebyshev_transform.jl @@ -14,7 +14,7 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray) return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch] end -function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N} +function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray, M) where {N} normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))) return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch] end diff --git a/src/Transform/fourier_transform.jl b/src/Transform/fourier_transform.jl index ed040f96..de1c8303 100644 --- a/src/Transform/fourier_transform.jl +++ b/src/Transform/fourier_transform.jl @@ -7,7 +7,7 @@ end Base.ndims(::FourierTransform{N}) where {N} = N function transform(ft::FourierTransform, 𝐱::AbstractArray) - return fft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch] + return rfft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch] end function low_pass(ft::FourierTransform, 𝐱_fft::AbstractArray) @@ -16,6 +16,6 @@ end truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft) -function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray) - return real(ifft(𝐱_fft, 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] +function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray, M) + return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] end diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl index fa0a2eea..d17fb52c 100644 --- a/src/operator_kernel.jl +++ b/src/operator_kernel.jl @@ -92,7 +92,7 @@ function operator_conv(m::OperatorConv, 𝐱::AbstractArray) 𝐱_padded = pad_modes(𝐱_applied_pattern, (size(𝐱_transformed)[1:(end - 2)]..., size(𝐱_applied_pattern)[(end - 1):end]...)) # [size(x)..., out_chs, batch] <- [modes..., out_chs, batch] - 𝐱_inversed = inverse(m.transform, 𝐱_padded) + 𝐱_inversed = inverse(m.transform, 𝐱_padded, size(𝐱)) return 𝐱_inversed end diff --git a/test/Transform/chebyshev_transform.jl b/test/Transform/chebyshev_transform.jl index ef43ec42..49bd8cd5 100644 --- a/test/Transform/chebyshev_transform.jl +++ b/test/Transform/chebyshev_transform.jl @@ -8,8 +8,8 @@ @test ndims(t) == 3 @test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch) @test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch) - @test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch) + @test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch) - g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱) + g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱) @test size(g[1]) == (30, 40, 50, ch, batch) end diff --git a/test/Transform/fourier_transform.jl b/test/Transform/fourier_transform.jl index e26ef19d..2c35bdd9 100644 --- a/test/Transform/fourier_transform.jl +++ b/test/Transform/fourier_transform.jl @@ -7,8 +7,8 @@ @test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch) @test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch) - @test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch) + @test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch) - g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)))), 𝐱) + g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)), size(𝐱))), 𝐱) @test size(g[1]) == (30, 40, 50, ch, batch) end From 9ac12ad9e5572831c7e239098d4fefe002811219 Mon Sep 17 00:00:00 2001 From: Tobias Knopp Date: Sat, 27 Aug 2022 12:57:59 +0200 Subject: [PATCH 2/5] fix tests --- test/Transform/fourier_transform.jl | 10 ++++++---- test/operator_kernel.jl | 8 ++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/test/Transform/fourier_transform.jl b/test/Transform/fourier_transform.jl index 2c35bdd9..d7c12d7e 100644 --- a/test/Transform/fourier_transform.jl +++ b/test/Transform/fourier_transform.jl @@ -5,10 +5,12 @@ ft = FourierTransform((3, 4, 5)) - @test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch) + @test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch) @test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch) - @test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch) + @test size(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, 𝐱)), size(transform(ft, 𝐱)) ), + size(𝐱))) == (30, 40, 50, ch, batch) - g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)), size(𝐱))), 𝐱) + g = Zygote.gradient(x -> sum(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, x)), + (16, 40, 50, ch, batch) ), (30, 40, 50, ch, batch) )), 𝐱) @test size(g[1]) == (30, 40, 50, ch, batch) -end +end \ No newline at end of file diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl index 2b02d2a3..30d030f1 100644 --- a/test/operator_kernel.jl +++ b/test/operator_kernel.jl @@ -71,7 +71,7 @@ end end @testset "2D OperatorConv" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Dense(1, 64), @@ -87,7 +87,7 @@ end end @testset "permuted 2D OperatorConv" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Conv((1, 1), 1 => 64), @@ -104,7 +104,7 @@ end end @testset "2D OperatorKernel" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Dense(1, 64), @@ -119,7 +119,7 @@ end end @testset "permuted 2D OperatorKernel" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Conv((1, 1), 1 => 64), From ff8c13f58896b53f35c44b3127df325e45733524 Mon Sep 17 00:00:00 2001 From: Tobias Knopp Date: Mon, 28 Nov 2022 21:59:55 +0100 Subject: [PATCH 3/5] add type declaration --- src/Transform/chebyshev_transform.jl | 2 +- src/Transform/fourier_transform.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Transform/chebyshev_transform.jl b/src/Transform/chebyshev_transform.jl index 6e7d6c6a..d10b5e4c 100644 --- a/src/Transform/chebyshev_transform.jl +++ b/src/Transform/chebyshev_transform.jl @@ -14,7 +14,7 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray) return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch] end -function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray, M) where {N} +function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray, M::NTuple{N, Int64}) where {N} normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))) return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch] end diff --git a/src/Transform/fourier_transform.jl b/src/Transform/fourier_transform.jl index de1c8303..b59ac9ac 100644 --- a/src/Transform/fourier_transform.jl +++ b/src/Transform/fourier_transform.jl @@ -16,6 +16,6 @@ end truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft) -function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray, M) +function inverse(ft::FourierTransform{N}, 𝐱_fft::AbstractArray, M::NTuple{N, Int64}) where {N} return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] end From cdea34ef521fe3e5752033379ed2690a7caf7b42 Mon Sep 17 00:00:00 2001 From: Tobias Knopp Date: Tue, 29 Nov 2022 10:14:01 +0100 Subject: [PATCH 4/5] fix signature --- src/Transform/chebyshev_transform.jl | 2 +- src/Transform/fourier_transform.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Transform/chebyshev_transform.jl b/src/Transform/chebyshev_transform.jl index d10b5e4c..dca54987 100644 --- a/src/Transform/chebyshev_transform.jl +++ b/src/Transform/chebyshev_transform.jl @@ -14,7 +14,7 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray) return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch] end -function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray, M::NTuple{N, Int64}) where {N} +function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N} normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))) return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch] end diff --git a/src/Transform/fourier_transform.jl b/src/Transform/fourier_transform.jl index b59ac9ac..201472f9 100644 --- a/src/Transform/fourier_transform.jl +++ b/src/Transform/fourier_transform.jl @@ -16,6 +16,6 @@ end truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft) -function inverse(ft::FourierTransform{N}, 𝐱_fft::AbstractArray, M::NTuple{N, Int64}) where {N} +function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N} return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] end From fed68b8c2b1798acb49e79744032aef0d5acdba8 Mon Sep 17 00:00:00 2001 From: Tobias Knopp Date: Tue, 29 Nov 2022 11:08:19 +0100 Subject: [PATCH 5/5] fix formatting --- src/Transform/chebyshev_transform.jl | 3 ++- src/Transform/fourier_transform.jl | 3 ++- test/Transform/chebyshev_transform.jl | 3 ++- test/Transform/fourier_transform.jl | 16 +++++++++++----- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/Transform/chebyshev_transform.jl b/src/Transform/chebyshev_transform.jl index dca54987..ee4efbb8 100644 --- a/src/Transform/chebyshev_transform.jl +++ b/src/Transform/chebyshev_transform.jl @@ -14,7 +14,8 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray) return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch] end -function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N} +function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T, N}, + M::NTuple{N, Int64}) where {T, N} normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))) return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch] end diff --git a/src/Transform/fourier_transform.jl b/src/Transform/fourier_transform.jl index 201472f9..f994a36f 100644 --- a/src/Transform/fourier_transform.jl +++ b/src/Transform/fourier_transform.jl @@ -16,6 +16,7 @@ end truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft) -function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N} +function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T, N}, + M::NTuple{N, Int64}) where {T, N} return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] end diff --git a/test/Transform/chebyshev_transform.jl b/test/Transform/chebyshev_transform.jl index 49bd8cd5..b15ff9df 100644 --- a/test/Transform/chebyshev_transform.jl +++ b/test/Transform/chebyshev_transform.jl @@ -8,7 +8,8 @@ @test ndims(t) == 3 @test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch) @test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch) - @test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch) + @test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) == + (3, 4, 5, ch, batch) g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱) @test size(g[1]) == (30, 40, 50, ch, batch) diff --git a/test/Transform/fourier_transform.jl b/test/Transform/fourier_transform.jl index d7c12d7e..4583906a 100644 --- a/test/Transform/fourier_transform.jl +++ b/test/Transform/fourier_transform.jl @@ -7,10 +7,16 @@ @test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch) @test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch) - @test size(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, 𝐱)), size(transform(ft, 𝐱)) ), - size(𝐱))) == (30, 40, 50, ch, batch) + @test size(inverse(ft, + NeuralOperators.pad_modes(truncate_modes(ft, transform(ft, 𝐱)), + size(transform(ft, 𝐱))), + size(𝐱))) == (30, 40, 50, ch, batch) - g = Zygote.gradient(x -> sum(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, x)), - (16, 40, 50, ch, batch) ), (30, 40, 50, ch, batch) )), 𝐱) + g = Zygote.gradient(x -> sum(inverse(ft, + NeuralOperators.pad_modes(truncate_modes(ft, + transform(ft, + x)), + (16, 40, 50, ch, batch)), + (30, 40, 50, ch, batch))), 𝐱) @test size(g[1]) == (30, 40, 50, ch, batch) -end \ No newline at end of file +end