diff --git a/src/Transform/chebyshev_transform.jl b/src/Transform/chebyshev_transform.jl index 6e7d6c6a..d10b5e4c 100644 --- a/src/Transform/chebyshev_transform.jl +++ b/src/Transform/chebyshev_transform.jl @@ -14,7 +14,7 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray) return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch] end -function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray, M) where {N} +function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray, M::NTuple{N, Int64}) where {N} normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1))) return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch] end diff --git a/src/Transform/fourier_transform.jl b/src/Transform/fourier_transform.jl index de1c8303..b59ac9ac 100644 --- a/src/Transform/fourier_transform.jl +++ b/src/Transform/fourier_transform.jl @@ -16,6 +16,6 @@ end truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft) -function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray, M) +function inverse(ft::FourierTransform{N}, 𝐱_fft::AbstractArray, M::NTuple{N, Int64}) where {N} return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch] end