diff --git a/Project.toml b/Project.toml index d8aaa3b..1f75ef4 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" +SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index 0f767b7..303b80a 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -7,7 +7,11 @@ module NeuralOperators using KernelAbstractions using Zygote using ChainRulesCore + using Polynomials + using SpecialPolynomials + include("utils.jl") + include("polynomials.jl") include("fourier.jl") include("wavelet.jl") include("model.jl") diff --git a/src/polynomials.jl b/src/polynomials.jl new file mode 100644 index 0000000..4670caa --- /dev/null +++ b/src/polynomials.jl @@ -0,0 +1,198 @@ +function legendre_ϕ_ψ(k) + # TODO: row-major -> column major + ϕ_coefs = zeros(k, k) + ϕ_2x_coefs = zeros(k, k) + + p = Polynomial([-1, 2]) # 2x-1 + p2 = Polynomial([-1, 4]) # 4x-1 + + for ki in 0:(k-1) + l = convert(Polynomial, gen_poly(Legendre, ki)) # Legendre of n=ki + ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2*ki+1) .* coeffs(l(p)) + ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2*(2*ki+1)) .* coeffs(l(p2)) + end + + ψ1_coefs .= ϕ_2x_coefs + ψ2_coefs = zeros(k, k) + for ki in 0:(k-1) + for i in 0:(k-1) + a = ϕ_2x_coefs[ki+1, 1:(ki+1)] + b = ϕ_coefs[i+1, 1:(i+1)] + proj_ = proj_factor(a, b) + view(ψ1_coefs, ki+1, :) .-= proj_ .* view(ϕ_coefs, i+1, :) + view(ψ2_coefs, ki+1, :) .-= proj_ .* view(ϕ_coefs, i+1, :) + end + + for j in 0:(k-1) + a = ϕ_2x_coefs[ki+1, 1:(ki+1)] + b = ψ1_coefs[j+1, :] + proj_ = proj_factor(a, b) + view(ψ1_coefs, ki+1, :) .-= proj_ .* view(ψ1_coefs, j+1, :) + view(ψ2_coefs, ki+1, :) .-= proj_ .* view(ψ2_coefs, j+1, :) + end + + a = ψ1_coefs[ki+1, :] + norm1 = proj_factor(a, a) + + a = ψ2_coefs[ki+1, :] + norm2 = proj_factor(a, a, complement=true) + norm_ = sqrt(norm1 + norm2) + ψ1_coefs[ki+1, :] ./= norm_ + ψ2_coefs[ki+1, :] ./= norm_ + zero_out!(ψ1_coefs) + zero_out!(ψ2_coefs) + end + + ϕ = [Polynomial(ϕ_coefs[i,:]) for i in 1:k] + ψ1 = [Polynomial(ψ1_coefs[i,:]) for i in 1:k] + ψ2 = [Polynomial(ψ2_coefs[i,:]) for i in 1:k] + + return ϕ, ψ1, ψ2 +end + +# function chebyshev_ϕ_ψ(k) +# ϕ_coefs = zeros(k, k) +# ϕ_2x_coefs = zeros(k, k) + +# p = Polynomial([-1, 2]) # 2x-1 +# p2 = Polynomial([-1, 4]) # 4x-1 + +# for ki in 0:(k-1) +# if ki == 0 +# ϕ_coefs[ki+1, 1:(ki+1)] .= sqrt(2/π) +# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(4/π) +# else +# c = convert(Polynomial, gen_poly(Chebyshev, ki)) # Chebyshev of n=ki +# ϕ_coefs[ki+1, 1:(ki+1)] .= 2/sqrt(π) .* coeffs(c(p)) +# ϕ_2x_coefs[ki+1, 1:(ki+1)] .= sqrt(2) * 2/sqrt(π) .* coeffs(c(p2)) +# end +# end + +# ϕ = [ϕ_(ϕ_coefs[i, :]) for i in 1:k] + +# k_use = 2k + +# # phi = [partial(phi_, phi_coeff[i,:]) for i in range(k)] + +# # x = Symbol('x') +# # kUse = 2*k +# # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() +# # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) +# # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) +# # # not needed for our purpose here, we use even k always to avoid +# # wm = np.pi / kUse / 2 + +# # psi1_coeff = np.zeros((k, k)) +# # psi2_coeff = np.zeros((k, k)) + +# # psi1 = [[] for _ in range(k)] +# # psi2 = [[] for _ in range(k)] + +# # for ki in range(k): +# # psi1_coeff[ki,:] = phi_2x_coeff[ki,:] +# # for i in range(k): +# # proj_ = (wm * phi[i](x_m) * np.sqrt(2)* phi[ki](2*x_m)).sum() +# # psi1_coeff[ki,:] -= proj_ * phi_coeff[i,:] +# # psi2_coeff[ki,:] -= proj_ * phi_coeff[i,:] + +# # for j in range(ki): +# # proj_ = (wm * psi1[j](x_m) * np.sqrt(2) * phi[ki](2*x_m)).sum() +# # psi1_coeff[ki,:] -= proj_ * psi1_coeff[j,:] +# # psi2_coeff[ki,:] -= proj_ * psi2_coeff[j,:] + +# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5) +# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5, ub = 1) + +# # norm1 = (wm * psi1[ki](x_m) * psi1[ki](x_m)).sum() +# # norm2 = (wm * psi2[ki](x_m) * psi2[ki](x_m)).sum() + +# # norm_ = np.sqrt(norm1 + norm2) +# # psi1_coeff[ki,:] /= norm_ +# # psi2_coeff[ki,:] /= norm_ +# # psi1_coeff[np.abs(psi1_coeff)<1e-8] = 0 +# # psi2_coeff[np.abs(psi2_coeff)<1e-8] = 0 + +# # psi1[ki] = partial(phi_, psi1_coeff[ki,:], lb = 0, ub = 0.5+1e-16) +# # psi2[ki] = partial(phi_, psi2_coeff[ki,:], lb = 0.5+1e-16, ub = 1) + +# # return phi, psi1, psi2 +# end + +function legendre_filter(k) + # x = Symbol('x') + # H0 = np.zeros((k,k)) + # H1 = np.zeros((k,k)) + # G0 = np.zeros((k,k)) + # G1 = np.zeros((k,k)) + # PHI0 = np.zeros((k,k)) + # PHI1 = np.zeros((k,k)) + # phi, psi1, psi2 = get_phi_psi(k, base) + + # ---------------------------------------------------------- + + # roots = Poly(legendre(k, 2*x-1)).all_roots() + # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # wm = 1/k/legendreDer(k,2*x_m-1)/eval_legendre(k-1,2*x_m-1) + + # for ki in range(k): + # for kpi in range(k): + # H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() + # G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() + # H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() + # G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() + + # PHI0 = np.eye(k) + # PHI1 = np.eye(k) + + # ---------------------------------------------------------- + + # H0[np.abs(H0)<1e-8] = 0 + # H1[np.abs(H1)<1e-8] = 0 + # G0[np.abs(G0)<1e-8] = 0 + # G1[np.abs(G1)<1e-8] = 0 + + # return H0, H1, G0, G1, PHI0, PHI1 +end + +function chebyshev_filter(k) + # x = Symbol('x') + # H0 = np.zeros((k,k)) + # H1 = np.zeros((k,k)) + # G0 = np.zeros((k,k)) + # G1 = np.zeros((k,k)) + # PHI0 = np.zeros((k,k)) + # PHI1 = np.zeros((k,k)) + # phi, psi1, psi2 = get_phi_psi(k, base) + + # ---------------------------------------------------------- + + # x = Symbol('x') + # kUse = 2*k + # roots = Poly(chebyshevt(kUse, 2*x-1)).all_roots() + # x_m = np.array([rt.evalf(20) for rt in roots]).astype(np.float64) + # # x_m[x_m==0.5] = 0.5 + 1e-8 # add small noise to avoid the case of 0.5 belonging to both phi(2x) and phi(2x-1) + # # not needed for our purpose here, we use even k always to avoid + # wm = np.pi / kUse / 2 + + # for ki in range(k): + # for kpi in range(k): + # H0[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki](x_m/2) * phi[kpi](x_m)).sum() + # G0[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, x_m/2) * phi[kpi](x_m)).sum() + # H1[ki, kpi] = 1/np.sqrt(2) * (wm * phi[ki]((x_m+1)/2) * phi[kpi](x_m)).sum() + # G1[ki, kpi] = 1/np.sqrt(2) * (wm * psi(psi1, psi2, ki, (x_m+1)/2) * phi[kpi](x_m)).sum() + + # PHI0[ki, kpi] = (wm * phi[ki](2*x_m) * phi[kpi](2*x_m)).sum() * 2 + # PHI1[ki, kpi] = (wm * phi[ki](2*x_m-1) * phi[kpi](2*x_m-1)).sum() * 2 + + # PHI0[np.abs(PHI0)<1e-8] = 0 + # PHI1[np.abs(PHI1)<1e-8] = 0 + + # ---------------------------------------------------------- + + # H0[np.abs(H0)<1e-8] = 0 + # H1[np.abs(H1)<1e-8] = 0 + # G0[np.abs(G0)<1e-8] = 0 + # G1[np.abs(G1)<1e-8] = 0 + + # return H0, H1, G0, G1, PHI0, PHI1 +end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..d9d5855 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,39 @@ +# function ϕ_(ϕ_coefs; lb::Real=0., ub::Real=1.) +# mask = +# return Polynomial(ϕ_coefs) +# end + +# def phi_(phi_c, x, lb = 0, ub = 1): +# mask = np.logical_or(xub) * 1.0 +# return np.polynomial.polynomial.Polynomial(phi_c)(x) * (1-mask) + +function ψ(ψ1, ψ2, i, inp) + mask = (inp ≤ 0.5) * 1.0 + return ψ1[i](inp) * mask + ψ2[i](inp) * (1-mask) +end + +zero_out!(x; tol=1e-8) = (x[abs.(x) .< tol] .= 0) + +function gen_poly(poly, n) + x = zeros(n+1) + x[end] = 1 + return poly(x) +end + +function convolve(a, b) + n = length(b) + y = similar(a, length(a)+n-1) + for i in 1:length(a) + y[i:(i+n-1)] .+= a[i] .* b + end + return y +end + +function proj_factor(a, b; complement::Bool=false) + prod_ = convolve(a, b) + zero_out!(prod_) + r = collect(1:length(prod_)) + s = complement ? (1 .- 0.5 .^ r) : (0.5 .^ r) + proj_ = sum(prod_ ./ r .* s) + return proj_ +end diff --git a/src/wavelet.jl b/src/wavelet.jl index 6dec15b..52e4fe2 100644 --- a/src/wavelet.jl +++ b/src/wavelet.jl @@ -1,24 +1,53 @@ -struct SparseKernel1d{T,S} +struct SparseKernel{T,S} k::Int conv_blk::S out_weight::T end -function SparseKernel1d(k::Int, c::Int=1; init=Flux.glorot_uniform) +function SparseKernel1d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) input_dim = c*k emb_dim = 128 conv = Conv((3,), input_dim=>emb_dim, relu; stride=1, pad=1, init=init) W_out = Dense(emb_dim, input_dim; init=init) - return SparseKernel1d(k, conv, W_out) + return SparseKernel(k, conv, W_out) end -function (l::SparseKernel1d)(X::AbstractArray) - X_ = l.conv_blk(batched_transpose(X)) - Y = l.out_weight(batched_transpose(X_)) - return Y +function SparseKernel2d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + conv = Conv((3, 3), input_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel(k, conv, W_out) +end + +function SparseKernel3d(k::Int, α, c::Int=1; init=Flux.glorot_uniform) + input_dim = c*k^2 + emb_dim = α*k^2 + conv = Conv((3, 3, 3), emb_dim=>emb_dim, relu; stride=1, pad=1, init=init) + W_out = Dense(emb_dim, input_dim; init=init) + return SparseKernel(k, conv, W_out) +end + +function (l::SparseKernel)(X::AbstractArray) + bch_sz, _, dims_r... = reverse(size(X)) + dims = reverse(dims_r) + + X_ = l.conv_blk(X) # (dims..., emb_dims, B) + X_ = reshape(X_, prod(dims), :, bch_sz) # (prod(dims), emb_dims, B) + Y = l.out_weight(batched_transpose(X_)) # (in_dims, prod(dims), B) + Y = reshape(batched_transpose(Y), dims..., :, bch_sz) # (dims..., in_dims, B) + return collect(Y) end +# struct MWT_CZ1d + +# end + +# function MWT_CZ1d(k::Int=3, c::Int=1; init=Flux.glorot_uniform) + +# end + # class MWT_CZ1d(nn.Module): # def __init__(self, # k = 3, alpha = 5, diff --git a/test/wavelet.jl b/test/wavelet.jl index 48642c1..d4e920d 100644 --- a/test/wavelet.jl +++ b/test/wavelet.jl @@ -1,13 +1,37 @@ using NeuralOperators +using CUDA +using Zygote + +CUDA.allowscalar(false) T = Float32 -k = 10 +k = 3 +batch_size = 32 + +α = 4 c = 1 in_chs = 20 -batch_size = 32 -l = NeuralOperators.SparseKernel1d(k, c) +l1 = NeuralOperators.SparseKernel1d(k, α, c) +X = rand(T, in_chs, c*k, batch_size) +Y = l1(X) +gradient(x->sum(l1(x)), X) + + +α = 4 +c = 3 +Nx = 5 +Ny = 7 + +l2 = NeuralOperators.SparseKernel2d(k, α, c) +X = rand(T, Nx, Ny, c*k^2, batch_size) +Y = l2(X) +gradient(x->sum(l2(x)), X) + +Nz = 13 -X = rand(T, c*k, in_chs, batch_size) -Y = l(X) +l3 = NeuralOperators.SparseKernel3d(k, α, c) +X = rand(T, Nx, Ny, Nz, α*k^2, batch_size) +Y = l3(X) +gradient(x->sum(l3(x)), X)