Skip to content
This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

Commit

Permalink
add WaveletTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Aug 31, 2022
1 parent 4e244b1 commit 09d085c
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 171 deletions.
153 changes: 28 additions & 125 deletions src/Transform/wavelet_transform.jl
Original file line number Diff line number Diff line change
@@ -1,130 +1,33 @@
export
SparseKernel,
SparseKernel1D,
SparseKernel2D,
SparseKernel3D


struct SparseKernel{N,T,S}
conv_blk::T
out_weight::S
end

function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
input_dim, emb_dim = ch
conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
W_out = Dense(emb_dim, input_dim; init=init)
return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
end

function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
input_dim = c*k
emb_dim = 128
return SparseKernel((3, ), input_dim=>emb_dim; init=init)
end

function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
input_dim = c*k^2
emb_dim = α*k^2
return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
end

function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
input_dim = c*k^2
emb_dim = α*k^2
conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
W_out = Dense(emb_dim, input_dim; init=init)
return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
end

Flux.@functor SparseKernel

function (l::SparseKernel)(X::AbstractArray)
bch_sz, _, dims_r... = reverse(size(X))
dims = reverse(dims_r)

X_ = l.conv_blk(X) # (dims..., emb_dims, B)
X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B)
Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B)
Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B)
return collect(Y)
end


struct MWT_CZ1d{T,S,R,Q,P}
k::Int
L::Int
A::T
B::S
C::R
T0::Q
ec_s::P
ec_d::P
rc_e::P
rc_o::P
end

function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
H0r = zero_out!(H0 * Φ0)
G0r = zero_out!(G0 * Φ0)
H1r = zero_out!(H1 * Φ1)
G1r = zero_out!(G1 * Φ1)

dim = c*k
A = SpectralConv(dim=>dim, (α,); init=init)
B = SpectralConv(dim=>dim, (α,); init=init)
C = SpectralConv(dim=>dim, (α,); init=init)
T0 = Dense(k, k)

ec_s = vcat(H0', H1')
ec_d = vcat(G0', G1')
rc_e = vcat(H0r, G0r)
rc_o = vcat(H1r, G1r)
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
end

function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
N = size(X, 3)
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
d = NNlib.batched_mul(Xa, l.ec_d)
s = NNlib.batched_mul(Xa, l.ec_s)
export WaveletTransform

struct WaveletTransform{N, S}<:AbstractTransform
ec_d
ec_s
modes::NTuple{N, S} # N == ndims(x)
end

Base.ndims(::WaveletTransform{N}) where {N} = N

function transform(wt::WaveletTransform, 𝐱::AbstractArray)
N = size(X, ndims(wt)-1)
# 1d
Xa = vcat(view(𝐱, :, :, 1:2:N, :), view(𝐱, :, :, 2:2:N, :))
# 2d
# Xa = vcat(
# view(𝐱, :, :, 1:2:N, 1:2:N, :),
# view(𝐱, :, :, 1:2:N, 2:2:N, :),
# view(𝐱, :, :, 2:2:N, 1:2:N, :),
# view(𝐱, :, :, 2:2:N, 2:2:N, :),
# )
d = NNlib.batched_mul(Xa, wt.ec_d)
s = NNlib.batched_mul(Xa, wt.ec_s)
return d, s
end

function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
bch_sz, N, dims_r... = reverse(size(X))
dims = reverse(dims_r)
@assert dims[1] == 2*l.k
Xₑ = NNlib.batched_mul(X, l.rc_e)
Xₒ = NNlib.batched_mul(X, l.rc_o)
# x = torch.zeros(B, N*2, c, self.k,
# device = x.device)
# x[..., ::2, :, :] = x_e
# x[..., 1::2, :, :] = x_o
return X
function inverse(wt::WaveletTransform, 𝐱_fwt::AbstractArray)

end

function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
bch_sz, N, dims_r... = reverse(size(X))
ns = floor(log2(N))
stop = ns - l.L

# decompose
Ud = T[]
Us = T[]
for i in 1:stop
d, X = wavelet_transform(l, X)
push!(Ud, l.A(d)+l.B(d))
push!(Us, l.C(d))
end
X = l.T0(X)

# reconstruct
for i in stop:-1:1
X += Us[i]
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
X = even_odd(l, X)
end
return X
end
# function truncate_modes(wt::WaveletTransform, 𝐱_fft::AbstractArray)
# return view(𝐱_fft, map(d->1:d, wt.modes)..., :, :) # [ft.modes..., in_chs, batch]
# end
148 changes: 145 additions & 3 deletions src/operator_kernel.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
export
OperatorConv,
SpectralConv,
OperatorKernel
OperatorConv,
SpectralConv,
OperatorKernel,
SparseKernel,
SparseKernel1D,
SparseKernel2D,
SparseKernel3D,
MWT_CZ1d

struct OperatorConv{P, T, S, TT}
weight::T
Expand Down Expand Up @@ -180,6 +185,143 @@ function (m::OperatorKernel)(𝐱)
return m.σ.(m.linear(𝐱) + m.conv(𝐱))
end

"""
SparseKernel(κ, ch, σ=identity)
Sparse kernel layer.
## Arguments
* `κ`: A neural network layer for approximation, e.g. a `Dense` layer or a MLP.
* `ch`: Channel size for linear transform, e.g. `32`.
* `σ`: Activation function.
"""
struct SparseKernel{N,T,S}
conv_blk::T
out_weight::S
end

function SparseKernel(filter::NTuple{N,T}, ch::Pair{S, S}; init=Flux.glorot_uniform) where {N,T,S}
input_dim, emb_dim = ch
conv = Conv(filter, input_dim=>emb_dim, relu; stride=1, pad=1, init=init)
W_out = Dense(emb_dim, input_dim; init=init)
return SparseKernel{N,typeof(conv),typeof(W_out)}(conv, W_out)
end

function SparseKernel1D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
input_dim = c*k
emb_dim = 128
return SparseKernel((3, ), input_dim=>emb_dim; init=init)
end

function SparseKernel2D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
input_dim = c*k^2
emb_dim = α*k^2
return SparseKernel((3, 3), input_dim=>emb_dim; init=init)
end

function SparseKernel3D(k::Int, α, c::Int=1; init=Flux.glorot_uniform)
input_dim = c*k^2
emb_dim = α*k^2
conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init)
W_out = Dense(emb_dim, input_dim; init=init)
return SparseKernel{3,typeof(conv),typeof(W_out)}(conv, W_out)
end

Flux.@functor SparseKernel

function (l::SparseKernel)(X::AbstractArray)
bch_sz, _, dims_r... = reverse(size(X))
dims = reverse(dims_r)

X_ = l.conv_blk(X) # (dims..., emb_dims, B)
X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B)
Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B)
Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B)
return collect(Y)
end


struct MWT_CZ1d{T,S,R,Q,P}
k::Int
L::Int
A::T
B::S
C::R
T0::Q
ec_s::P
ec_d::P
rc_e::P
rc_o::P
end

function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
H0, H1, G0, G1, Φ0, Φ1 = get_filter(base, k)
H0r = zero_out!(H0 * Φ0)
G0r = zero_out!(G0 * Φ0)
H1r = zero_out!(H1 * Φ1)
G1r = zero_out!(G1 * Φ1)

dim = c*k
A = SpectralConv(dim=>dim, (α,); init=init)
B = SpectralConv(dim=>dim, (α,); init=init)
C = SpectralConv(dim=>dim, (α,); init=init)
T0 = Dense(k, k)

ec_s = vcat(H0', H1')
ec_d = vcat(G0', G1')
rc_e = vcat(H0r, G0r)
rc_o = vcat(H1r, G1r)
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
end

function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
N = size(X, 3)
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
d = NNlib.batched_mul(Xa, l.ec_d)
s = NNlib.batched_mul(Xa, l.ec_s)
return d, s
end

function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
bch_sz, N, dims_r... = reverse(size(X))
dims = reverse(dims_r)
@assert dims[1] == 2*l.k
Y = similar(X, bch_sz, 2N, l.c, l.k)
view(Y, :, :, 1:2:N, :) .= NNlib.batched_mul(X, l.rc_e)
view(Y, :, :, 2:2:N, :) .= NNlib.batched_mul(X, l.rc_o)
return Y
end

function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
bch_sz, N, dims_r... = reverse(size(X))
ns = floor(log2(N))
stop = ns - l.L

# decompose
Ud = T[]
Us = T[]
for i in 1:stop
d, X = wavelet_transform(l, X)
push!(Ud, l.A(d)+l.B(d))
push!(Us, l.C(d))
end
X = l.T0(X)

# reconstruct
for i in stop:-1:1
X += Us[i]
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
X = even_odd(l, X)
end
return X
end

# function Base.show(io::IO, l::MWT_CZ1d)
# print(io, "MWT_CZ($(l.in_channel) => $(l.out_channel), $(l.transform.modes), $(nameof(typeof(l.transform))), permuted=$P)")
# end


#########
# utils #
#########
Expand Down
File renamed without changes.
63 changes: 20 additions & 43 deletions test/Transform/wavelet_transform.jl
Original file line number Diff line number Diff line change
@@ -1,53 +1,30 @@
@testset "SparseKernel" begin
@testset "wavelet transform" begin
𝐱 = rand(30, 40, 50, 6, 7) # where ch == 6 and batch == 7

wt = WaveletTransform((3, 4, 5))

@test size(transform(wt, 𝐱)) == (30, 40, 50, 6, 7)
@test size(truncate_modes(wt, transform(wt, 𝐱))) == (3, 4, 5, 6, 7)
@test size(inverse(wt, truncate_modes(wt, transform(wt, 𝐱)))) == (3, 4, 5, 6, 7)
end

@testset "MWT_CZ" begin
T = Float32
k = 3
batch_size = 32

@testset "1D SparseKernel" begin
α = 4
c = 1
in_chs = 20
X = rand(T, in_chs, c*k, batch_size)
@testset "MWT_CZ1d" begin
mwt = MWT_CZ1d()

l1 = SparseKernel1D(k, α, c)
Y = l1(X)
@test l1 isa SparseKernel{1}
@test size(Y) == size(X)
# base functions
wavelet_transform(mwt, )
even_odd(mwt, )

gs = gradient(()->sum(l1(X)), Flux.params(l1))
@test length(gs.grads) == 4
end
# forward
Y = mwt(X)

@testset "2D SparseKernel" begin
α = 4
c = 3
Nx = 5
Ny = 7
X = rand(T, Nx, Ny, c*k^2, batch_size)

l2 = SparseKernel2D(k, α, c)
Y = l2(X)
@test l2 isa SparseKernel{2}
@test size(Y) == size(X)

gs = gradient(()->sum(l2(X)), Flux.params(l2))
@test length(gs.grads) == 4
# backward
g = gradient()
end

@testset "3D SparseKernel" begin
α = 4
c = 3
Nx = 5
Ny = 7
Nz = 13
X = rand(T, Nx, Ny, Nz, α*k^2, batch_size)

l3 = SparseKernel3D(k, α, c)
Y = l3(X)
@test l3 isa SparseKernel{3}
@test size(Y) == (Nx, Ny, Nz, c*k^2, batch_size)

gs = gradient(()->sum(l3(X)), Flux.params(l3))
@test length(gs.grads) == 4
end
end
Loading

0 comments on commit 09d085c

Please sign in to comment.