diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl index 2b02d2a3..2d00b4ff 100644 --- a/test/operator_kernel.jl +++ b/test/operator_kernel.jl @@ -142,3 +142,21 @@ end @test SpectralConv(ch, modes) isa OperatorConv @test SpectralConv(ch, modes).transform isa FourierTransform end + +@testset "1D OperatorConv with ChebyshevTransform" begin + modes = (16,) + ch = 64 => 128 + + m = Chain(Dense(2, 64), + OperatorConv(ch, modes, ChebyshevTransform)) + @test ndims(OperatorConv(ch, modes, ChebyshevTransform)) == 1 + @test repr(OperatorConv(ch, modes, ChebyshevTransform)) == + "OperatorConv(64 => 128, (16,), ChebyshevTransform, permuted=false)" + + 𝐱 = rand(Float32, 2, 1024, 5) + @test size(m(𝐱)) == (128, 1024, 5) + + loss(x, y) = Flux.mse(m(x), y) + data = [(𝐱, rand(Float32, 128, 1024, 5))] + Flux.train!(loss, Flux.params(m), data, Flux.Adam()) +end