diff --git a/src/Transform/chebyshev_transform.jl b/src/Transform/chebyshev_transform.jl index a6c0191d..e83a9fff 100644 --- a/src/Transform/chebyshev_transform.jl +++ b/src/Transform/chebyshev_transform.jl @@ -15,7 +15,8 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray) return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch] end -function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N} +function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T, N}, + M::NTuple{N, Int64}) where {T, N} normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))) return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch] end diff --git a/src/Transform/fourier_transform.jl b/src/Transform/fourier_transform.jl index 09e581b3..bab74918 100644 --- a/src/Transform/fourier_transform.jl +++ b/src/Transform/fourier_transform.jl @@ -8,7 +8,7 @@ Base.ndims(::FourierTransform{N}) where {N} = N Base.eltype(::Type{FourierTransform}) = ComplexF32 function transform(ft::FourierTransform, 𝐱::AbstractArray) - return fft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch] + return rfft(Zygote.hook(real, 𝐱), 1:ndims(ft)) # [size(x)..., in_chs, batch] end function low_pass(ft::FourierTransform, 𝐱_fft::AbstractArray) @@ -17,6 +17,7 @@ end truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft) -function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray) - return real(ifft(𝐱_fft, 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] +function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T, N}, + M::NTuple{N, Int64}) where {T, N} + return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] end diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl index d131ad34..453ab541 100644 --- a/src/operator_kernel.jl +++ b/src/operator_kernel.jl @@ -92,7 +92,7 @@ function operator_conv(m::OperatorConv, 𝐱::AbstractArray) 𝐱_padded = pad_modes(𝐱_applied_pattern, (size(𝐱_transformed)[1:(end - 2)]..., size(𝐱_applied_pattern)[(end - 1):end]...)) # [size(x)..., out_chs, batch] <- [modes..., out_chs, batch] - 𝐱_inversed = inverse(m.transform, 𝐱_padded) + 𝐱_inversed = inverse(m.transform, 𝐱_padded, size(𝐱)) return 𝐱_inversed end diff --git a/test/Transform/chebyshev_transform.jl b/test/Transform/chebyshev_transform.jl index ef43ec42..b15ff9df 100644 --- a/test/Transform/chebyshev_transform.jl +++ b/test/Transform/chebyshev_transform.jl @@ -8,8 +8,9 @@ @test ndims(t) == 3 @test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch) @test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch) - @test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch) + @test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) == + (3, 4, 5, ch, batch) - g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱) + g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱) @test size(g[1]) == (30, 40, 50, ch, batch) end diff --git a/test/Transform/fourier_transform.jl b/test/Transform/fourier_transform.jl index e26ef19d..4583906a 100644 --- a/test/Transform/fourier_transform.jl +++ b/test/Transform/fourier_transform.jl @@ -5,10 +5,18 @@ ft = FourierTransform((3, 4, 5)) - @test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch) + @test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch) @test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch) - @test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch) + @test size(inverse(ft, + NeuralOperators.pad_modes(truncate_modes(ft, transform(ft, 𝐱)), + size(transform(ft, 𝐱))), + size(𝐱))) == (30, 40, 50, ch, batch) - g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)))), 𝐱) + g = Zygote.gradient(x -> sum(inverse(ft, + NeuralOperators.pad_modes(truncate_modes(ft, + transform(ft, + x)), + (16, 40, 50, ch, batch)), + (30, 40, 50, ch, batch))), 𝐱) @test size(g[1]) == (30, 40, 50, ch, batch) end diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl index 2d00b4ff..c64396bc 100644 --- a/test/operator_kernel.jl +++ b/test/operator_kernel.jl @@ -71,7 +71,7 @@ end end @testset "2D OperatorConv" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Dense(1, 64), @@ -87,7 +87,7 @@ end end @testset "permuted 2D OperatorConv" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Conv((1, 1), 1 => 64), @@ -104,7 +104,7 @@ end end @testset "2D OperatorKernel" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Dense(1, 64), @@ -119,7 +119,7 @@ end end @testset "permuted 2D OperatorKernel" begin - modes = (16, 16) + modes = (10, 10) ch = 64 => 64 m = Chain(Conv((1, 1), 1 => 64),