From fb3a7e0b11454bc92437f21569aaef72217affb7 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sat, 26 Mar 2022 23:36:25 +0800 Subject: [PATCH] add WaveletTransform --- src/Transform/wavelet_transform.jl | 153 +++++----------------------- src/operator_kernel.jl | 148 ++++++++++++++++++++++++++- test/{ => Transform}/polynomials.jl | 0 test/Transform/wavelet_transform.jl | 63 ++++-------- test/operator_kernel.jl | 54 ++++++++++ 5 files changed, 247 insertions(+), 171 deletions(-) rename test/{ => Transform}/polynomials.jl (100%) diff --git a/src/Transform/wavelet_transform.jl b/src/Transform/wavelet_transform.jl index 26749c73..41d54edb 100644 --- a/src/Transform/wavelet_transform.jl +++ b/src/Transform/wavelet_transform.jl @@ -1,130 +1,33 @@ -export - SparseKernel, - SparseKernel1D, - SparseKernel2D, - SparseKernel3D - - -struct SparseKernel{N,T,S} - conv_blk::T - out_weight::S -end - -function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} - input_dim, emb_dim = ch - conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) -end - -function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k - emb_dim = 128 - return SparseKernel((3, ), input_dim=>emb_dim; init=init) -end - -function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - return SparseKernel((3, 3), input_dim=>emb_dim; init=init) -end - -function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) - input_dim = c*k^2 - emb_dim = α*k^2 - conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) - W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) -end - -Flux.@functor SparseKernel - -function (l::SparseKernel)(X::AbstractArray) - bch_sz, _, dims_r... = reverse(size(X)) - dims = reverse(dims_r) - - X_ = l.conv_blk(X) # (dims..., emb_dims, B) - X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) - Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) - Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) - return collect(Y) -end - - -struct MWT_CZ1d{T,S,R,Q,P} - k::Int - L::Int - A::T - B::S - C::R - T0::Q - ec_s::P - ec_d::P - rc_e::P - rc_o::P -end - -function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) - H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) - H0r = zero_out!(H0 * Φ0) - G0r = zero_out!(G0 * Φ0) - H1r = zero_out!(H1 * Φ1) - G1r = zero_out!(G1 * Φ1) - - dim = c*k - A = SpectralConv(dim=>dim, (α,); init=init) - B = SpectralConv(dim=>dim, (α,); init=init) - C = SpectralConv(dim=>dim, (α,); init=init) - T0 = Dense(k, k) - - ec_s = vcat(H0', H1') - ec_d = vcat(G0', G1') - rc_e = vcat(H0r, G0r) - rc_o = vcat(H1r, G1r) - return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) -end - -function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} - N = size(X, 3) - Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) - d = NNlib.batched_mul(Xa, l.ec_d) - s = NNlib.batched_mul(Xa, l.ec_s) +export WaveletTransform + +struct WaveletTransform{N, S}<:AbstractTransform + ec_d + ec_s + modes::NTuple{N, S} # N == ndims(x) +end + +Base.ndims(::WaveletTransform{N}) where {N} = N + +function transform(wt::WaveletTransform, 𝐱::AbstractArray) + N = size(X, ndims(wt)-1) + # 1d + Xa = vcat(view(𝐱, :, :, 1:2:N, :), view(𝐱, :, :, 2:2:N, :)) + # 2d + # Xa = vcat( + # view(𝐱, :, :, 1:2:N, 1:2:N, :), + # view(𝐱, :, :, 1:2:N, 2:2:N, :), + # view(𝐱, :, :, 2:2:N, 1:2:N, :), + # view(𝐱, :, :, 2:2:N, 2:2:N, :), + # ) + d = NNlib.batched_mul(Xa, wt.ec_d) + s = NNlib.batched_mul(Xa, wt.ec_s) return d, s end -function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} - bch_sz, N, dims_r... = reverse(size(X)) - dims = reverse(dims_r) - @assert dims[1] == 2*l.k - Xₑ = NNlib.batched_mul(X, l.rc_e) - Xₒ = NNlib.batched_mul(X, l.rc_o) -# x = torch.zeros(B, N*2, c, self.k, -# device = x.device) -# x[..., ::2, :, :] = x_e -# x[..., 1::2, :, :] = x_o - return X +function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray) + end -function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} - bch_sz, N, dims_r... = reverse(size(X)) - ns = floor(log2(N)) - stop = ns - l.L - - # decompose - Ud = T[] - Us = T[] - for i in 1:stop - d, X = wavelet_transform(l, X) - push!(Ud, l.A(d)+l.B(d)) - push!(Us, l.C(d)) - end - X = l.T0(X) - - # reconstruct - for i in stop:-1:1 - X += Us[i] - X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1) - X = even_odd(l, X) - end - return X -end +# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray) +# return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch] +# end diff --git a/src/operator_kernel.jl b/src/operator_kernel.jl index 7b1f3eb7..625cec0c 100644 --- a/src/operator_kernel.jl +++ b/src/operator_kernel.jl @@ -1,7 +1,12 @@ export - OperatorConv, - SpectralConv, - OperatorKernel + OperatorConv, + SpectralConv, + OperatorKernel, + SparseKernel, + SparseKernel1D, + SparseKernel2D, + SparseKernel3D, + MWT_CZ1d struct OperatorConv{P, T, S, TT} weight::T @@ -180,6 +185,143 @@ function (m::OperatorKernel)(𝐱) return m.σ.(m.linear(𝐱) + m.conv(𝐱)) end +""" + SparseKernel(κ, ch, σ=identity) + +Sparse kernel layer. + +## Arguments + +* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP. +* `ch`: Channel size for linear transform, e.g. `32`. +* `σ`: Activation function. +""" +struct SparseKernel{N,T,S} + conv_blk::T + out_weight::S +end + +function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S} + input_dim, emb_dim = ch + conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out) +end + +function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k + emb_dim = 128 + return SparseKernel((3, ), input_dim=>emb_dim; init=init) +end + +function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + return SparseKernel((3, 3), input_dim=>emb_dim; init=init) +end + +function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out) +end + +Flux.@functor SparseKernel + +function (l::SparseKernel)(X::AbstractArray) + bch_sz, _, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + + X_ = l.conv_blk(X) # (dims..., emb_dims, B) + X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) + Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) + Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) + return collect(Y) +end + + +struct MWT_CZ1d{T,S,R,Q,P} + k::Int + L::Int + A::T + B::S + C::R + T0::Q + ec_s::P + ec_d::P + rc_e::P + rc_o::P +end + +function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform) + H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k) + H0r = zero_out!(H0 * Φ0) + G0r = zero_out!(G0 * Φ0) + H1r = zero_out!(H1 * Φ1) + G1r = zero_out!(G1 * Φ1) + + dim = c*k + A = SpectralConv(dim=>dim, (α,); init=init) + B = SpectralConv(dim=>dim, (α,); init=init) + C = SpectralConv(dim=>dim, (α,); init=init) + T0 = Dense(k, k) + + ec_s = vcat(H0', H1') + ec_d = vcat(G0', G1') + rc_e = vcat(H0r, G0r) + rc_o = vcat(H1r, G1r) + return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o) +end + +function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + N = size(X, 3) + Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :)) + d = NNlib.batched_mul(Xa, l.ec_d) + s = NNlib.batched_mul(Xa, l.ec_s) + return d, s +end + +function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T} + bch_sz, N, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + @assert dims[1] == 2*l.k + Y = similar(X, bch_sz, 2N, l.c, l.k) + view(Y, :, :, 1:2:N, :) .= NNlib.batched_mul(X, l.rc_e) + view(Y, :, :, 2:2:N, :) .= NNlib.batched_mul(X, l.rc_o) + return Y +end + +function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray} + bch_sz, N, dims_r... = reverse(size(X)) + ns = floor(log2(N)) + stop = ns - l.L + + # decompose + Ud = T[] + Us = T[] + for i in 1:stop + d, X = wavelet_transform(l, X) + push!(Ud, l.A(d)+l.B(d)) + push!(Us, l.C(d)) + end + X = l.T0(X) + + # reconstruct + for i in stop:-1:1 + X += Us[i] + X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1) + X = even_odd(l, X) + end + return X +end + +# function Base.show(io::IO, l::MWT_CZ1d) +# print(io, "MWT_CZ($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)") +# end + + ######### # utils # ######### diff --git a/test/polynomials.jl b/test/Transform/polynomials.jl similarity index 100% rename from test/polynomials.jl rename to test/Transform/polynomials.jl diff --git a/test/Transform/wavelet_transform.jl b/test/Transform/wavelet_transform.jl index 726727eb..48705bf3 100644 --- a/test/Transform/wavelet_transform.jl +++ b/test/Transform/wavelet_transform.jl @@ -1,53 +1,30 @@ -@testset "SparseKernel" begin +@testset "wavelet transform" begin + 𝐱 = rand(30, 40, 50, 6, 7) # where ch == 6 and batch == 7 + + wt = WaveletTransform((3, 4, 5)) + + @test size(transform(wt, 𝐱)) == (30, 40, 50, 6, 7) + @test size(truncate_modes(wt, transform(wt, 𝐱))) == (3, 4, 5, 6, 7) + @test size(inverse(wt, truncate_modes(wt, transform(wt, 𝐱)))) == (3, 4, 5, 6, 7) +end + +@testset "MWT_CZ" begin T = Float32 k = 3 batch_size = 32 - @testset "1D SparseKernel" begin - α = 4 - c = 1 - in_chs = 20 - X = rand(T, in_chs, c*k, batch_size) + @testset "MWT_CZ1d" begin + mwt = MWT_CZ1d() - l1 = SparseKernel1D(k, α, c) - Y = l1(X) - @test l1 isa SparseKernel{1} - @test size(Y) == size(X) + # base functions + wavelet_transform(mwt, ) + even_odd(mwt, ) - gs = gradient(()->sum(l1(X)), Flux.params(l1)) - @test length(gs.grads) == 4 - end + # forward + Y = mwt(X) - @testset "2D SparseKernel" begin - α = 4 - c = 3 - Nx = 5 - Ny = 7 - X = rand(T, Nx, Ny, c*k^2, batch_size) - - l2 = SparseKernel2D(k, α, c) - Y = l2(X) - @test l2 isa SparseKernel{2} - @test size(Y) == size(X) - - gs = gradient(()->sum(l2(X)), Flux.params(l2)) - @test length(gs.grads) == 4 + # backward + g = gradient() end - @testset "3D SparseKernel" begin - α = 4 - c = 3 - Nx = 5 - Ny = 7 - Nz = 13 - X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) - - l3 = SparseKernel3D(k, α, c) - Y = l3(X) - @test l3 isa SparseKernel{3} - @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) - - gs = gradient(()->sum(l3(X)), Flux.params(l3)) - @test length(gs.grads) == 4 - end end diff --git a/test/operator_kernel.jl b/test/operator_kernel.jl index 2b02d2a3..be20b01b 100644 --- a/test/operator_kernel.jl +++ b/test/operator_kernel.jl @@ -142,3 +142,57 @@ end @test SpectralConv(ch, modes) isa OperatorConv @test SpectralConv(ch, modes).transform isa FourierTransform end + +@testset "SparseKernel" begin + T = Float32 + k = 3 + batch_size = 32 + + @testset "1D SparseKernel" begin + α = 4 + c = 1 + in_chs = 20 + X = rand(T, in_chs, c*k, batch_size) + + l1 = SparseKernel1D(k, α, c) + Y = l1(X) + @test l1 isa SparseKernel{1} + @test size(Y) == size(X) + + gs = gradient(()->sum(l1(X)), Flux.params(l1)) + @test length(gs.grads) == 4 + end + + @testset "2D SparseKernel" begin + α = 4 + c = 3 + Nx = 5 + Ny = 7 + X = rand(T, Nx, Ny, c*k^2, batch_size) + + l2 = SparseKernel2D(k, α, c) + Y = l2(X) + @test l2 isa SparseKernel{2} + @test size(Y) == size(X) + + gs = gradient(()->sum(l2(X)), Flux.params(l2)) + @test length(gs.grads) == 4 + end + + @testset "3D SparseKernel" begin + α = 4 + c = 3 + Nx = 5 + Ny = 7 + Nz = 13 + X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) + + l3 = SparseKernel3D(k, α, c) + Y = l3(X) + @test l3 isa SparseKernel{3} + @test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size) + + gs = gradient(()->sum(l3(X)), Flux.params(l3)) + @test length(gs.grads) == 4 + end +end