Skip to content
This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

Commit

Permalink
draft for MWT_CZ1d
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Aug 31, 2022
1 parent 15c6b06 commit 37e5d48
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 86 deletions.
10 changes: 10 additions & 0 deletions src/Transform/polynomials.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
function get_filter(base::Symbol, k)
if base == :legendre
return legendre_filter(k)
elseif base == :chebyshev
return chebyshev_filter(k)
else
throw(ArgumentError("base must be one of :legendre or :chebyshev."))
end
end

function legendre_ϕ_ψ(k)
# TODO: row-major -> column major
ϕ_coefs = zeros(k, k)
Expand Down
159 changes: 73 additions & 86 deletions src/Transform/wavelet_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,93 +51,80 @@ function (l::SparseKernel)(X::AbstractArray)
end


# struct MWT_CZ1d

# end

# function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform)

# end

# class MWT_CZ1d(nn.Module):
# def __init__(self,
# k = 3, alpha = 5,
# L = 0, c = 1,
# base = 'legendre',
# initializer = None,
# **kwargs):
# super(MWT_CZ1d, self).__init__()

# self.k = k
# self.L = L
# H0, H1, G0, G1, PHI0, PHI1 = get_filter(base, k)
# H0r = H0@PHI0
# G0r = G0@PHI0
# H1r = H1@PHI1
# G1r = G1@PHI1

# H0r[np.abs(H0r)<1e-8]=0
# H1r[np.abs(H1r)<1e-8]=0
# G0r[np.abs(G0r)<1e-8]=0
# G1r[np.abs(G1r)<1e-8]=0

# self.A = sparseKernelFT1d(k, alpha, c)
# self.B = sparseKernelFT1d(k, alpha, c)
# self.C = sparseKernelFT1d(k, alpha, c)

# self.T0 = nn.Linear(k, k)

# self.register_buffer('ec_s', torch.Tensor(
# np.concatenate((H0.T, H1.T), axis=0)))
# self.register_buffer('ec_d', torch.Tensor(
# np.concatenate((G0.T, G1.T), axis=0)))

# self.register_buffer('rc_e', torch.Tensor(
# np.concatenate((H0r, G0r), axis=0)))
# self.register_buffer('rc_o', torch.Tensor(
# np.concatenate((H1r, G1r), axis=0)))


# def forward(self, x):

# B, N, c, ich = x.shape # (B, N, k)
# ns = math.floor(np.log2(N))

# Ud = torch.jit.annotate(List[Tensor], [])
# Us = torch.jit.annotate(List[Tensor], [])
# # decompose
# for i in range(ns-self.L):
# d, x = self.wavelet_transform(x)
# Ud += [self.A(d) + self.B(x)]
# Us += [self.C(d)]
# x = self.T0(x) # coarsest scale transform

# # reconstruct
# for i in range(ns-1-self.L,-1,-1):
# x = x + Us[i]
# x = torch.cat((x, Ud[i]), -1)
# x = self.evenOdd(x)
# return x


# def wavelet_transform(self, x):
# xa = torch.cat([x[:, ::2, :, :],
# x[:, 1::2, :, :],
# ], -1)
# d = torch.matmul(xa, self.ec_d)
# s = torch.matmul(xa, self.ec_s)
# return d, s


# def evenOdd(self, x):

# B, N, c, ich = x.shape # (B, N, c, k)
# assert ich == 2*self.k
# x_e = torch.matmul(x, self.rc_e)
# x_o = torch.matmul(x, self.rc_o)

struct MWT_CZ1d{T,S,R,Q,P}
k::Int
L::Int
A::T
B::S
C::R
T0::Q
ec_s::P
ec_d::P
rc_e::P
rc_o::P
end

function MWT_CZ1d(k::Int=3, α::Int=5, L::Int=0, c::Int=1; base::Symbol=:legendre, init=Flux.glorot_uniform)
H0, H1, G0, G1, ϕ0, ϕ1 = get_filter(base, k)
H0r = zero_out!(H0 * ϕ0)
G0r = zero_out!(G0 * ϕ0)
H1r = zero_out!(H1 * ϕ1)
G1r = zero_out!(G1 * ϕ1)

dim = c*k
A = SpectralConv(dim=>dim, (α,); init=init)
B = SpectralConv(dim=>dim, (α,); init=init)
C = SpectralConv(dim=>dim, (α,); init=init)
T0 = Dense(k, k)

ec_s = vcat(H0', H1')
ec_d = vcat(G0', G1')
rc_e = vcat(H0r, G0r)
rc_o = vcat(H1r, G1r)
return MWT_CZ1d(k, L, A, B, C, T0, ec_s, ec_d, rc_e, rc_o)
end

function wavelet_transform(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
N = size(X, 3)
Xa = vcat(view(X, :, :, 1:2:N, :), view(X, :, :, 2:2:N, :))
d = NNlib.batched_mul(Xa, l.ec_d)
s = NNlib.batched_mul(Xa, l.ec_s)
return d, s
end

function even_odd(l::MWT_CZ1d, X::AbstractArray{T,4}) where {T}
bch_sz, N, dims_r... = reverse(size(X))
dims = reverse(dims_r)
@assert dims[1] == 2*l.k
Xₑ = NNlib.batched_mul(X, l.rc_e)
Xₒ = NNlib.batched_mul(X, l.rc_o)
# x = torch.zeros(B, N*2, c, self.k,
# device = x.device)
# x[..., ::2, :, :] = x_e
# x[..., 1::2, :, :] = x_o
# return x
return X
end

function (l::MWT_CZ1d)(X::T) where {T<:AbstractArray}
bch_sz, N, dims_r... = reverse(size(X))
ns = floor(log2(N))
stop = ns - l.L

# decompose
Ud = T[]
Us = T[]
for i in 1:stop
d, X = wavelet_transform(l, X)
push!(Ud, l.A(d)+l.B(d))
push!(Us, l.C(d))
end
X = l.T0(X)

# reconstruct
for i in stop:-1:1
X += Us[i]
X = vcat(X, Ud[i]) # x = torch.cat((x, Ud[i]), -1)
X = even_odd(l, X)
end
return X
end

0 comments on commit 37e5d48

Please sign in to comment.