Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tknopp committed Nov 29, 2022
1 parent cdea34e commit fed68b8
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/Transform/chebyshev_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
end

function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N}
function inverse(t::ChebyshevTransform, 𝐱̂::AbstractArray{T, N},
M::NTuple{N, Int64}) where {T, N}
normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1)))
return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch]
end
Expand Down
3 changes: 2 additions & 1 deletion src/Transform/fourier_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ end

truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft)

function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T,N}, M::NTuple{N, Int64}) where {T,N}
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray{T, N},
M::NTuple{N, Int64}) where {T, N}
return real(irfft(𝐱_fft, M[1], 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]
end
3 changes: 2 additions & 1 deletion test/Transform/chebyshev_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
@test ndims(t) == 3
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) == (3, 4, 5, ch, batch)
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)), size(𝐱))) ==
(3, 4, 5, ch, batch)

g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)), size(𝐱))), 𝐱)
@test size(g[1]) == (30, 40, 50, ch, batch)
Expand Down
16 changes: 11 additions & 5 deletions test/Transform/fourier_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,16 @@

@test size(transform(ft, 𝐱)) == (16, 40, 50, ch, batch)
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
@test size(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, 𝐱)), size(transform(ft, 𝐱)) ),
size(𝐱))) == (30, 40, 50, ch, batch)
@test size(inverse(ft,
NeuralOperators.pad_modes(truncate_modes(ft, transform(ft, 𝐱)),
size(transform(ft, 𝐱))),
size(𝐱))) == (30, 40, 50, ch, batch)

g = Zygote.gradient(x -> sum(inverse(ft, NeuralOperators.pad_modes( truncate_modes(ft, transform(ft, x)),
(16, 40, 50, ch, batch) ), (30, 40, 50, ch, batch) )), 𝐱)
g = Zygote.gradient(x -> sum(inverse(ft,
NeuralOperators.pad_modes(truncate_modes(ft,
transform(ft,
x)),
(16, 40, 50, ch, batch)),
(30, 40, 50, ch, batch))), 𝐱)
@test size(g[1]) == (30, 40, 50, ch, batch)
end
end

0 comments on commit fed68b8

Please sign in to comment.