diff --git a/Project.toml b/Project.toml index 01fc241f..1ac81ed4 100644 --- a/Project.toml +++ b/Project.toml @@ -4,19 +4,25 @@ authors = ["JingYu Ning and contributors"] version = "0.1.0" [deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MAT = "23992714-dd62-5051-b70f-ba57cb901cac" Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +CUDA = "3.3" +CUDAKernels = "0.3" DataDeps = "0.7" FFTW = "1.4" Fetch = "0.1" Flux = "0.12" +KernelAbstractions = "0.7" MAT = "0.10" Tullio = "0.3" Zygote = "0.6" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 971ca0b2..fa602ff4 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -59,7 +59,7 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" [[NeuralOperators]] path = ".." -uuid = "9ab867d4-5049-4b07-85bc-95379d8d6d9c" +uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" version = "0.1.0" [[Parsers]] diff --git a/docs/Project.toml b/docs/Project.toml index fd3392d4..de8c9086 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,3 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -NeuralOperators = "9ab867d4-5049-4b07-85bc-95379d8d6d9c" +NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b" diff --git a/example/burgers.jl b/example/burgers.jl index 41ef23b0..3dc24775 100644 --- a/example/burgers.jl +++ b/example/burgers.jl @@ -1,14 +1,14 @@ using NeuralOperators using Flux -# using CUDA +using CUDA -# if has_cuda() -# @info "CUDA is on" -# device = gpu -# CUDA.allowscalar(false) -# else +if has_cuda() + @info "CUDA is on" + device = gpu + CUDA.allowscalar(false) +else device = cpu -# end +end m = FourierNeuralOperator() |> device loss(๐ฑ, ๐ฒ) = sum(abs2, ๐ฒ .- m(๐ฑ)) / size(๐ฑ)[end] @@ -24,8 +24,9 @@ n_test = 40 loader_test = Flux.DataLoader((๐ฑ_test, ๐ฒ_test), batchsize=20, shuffle=false) function loss_test() - l = 0 + l = 0f0 for (๐ฑ, ๐ฒ) in loader_test + ๐ฑ, ๐ฒ = device(๐ฑ), device(๐ฒ) l += loss(๐ฑ, ๐ฒ) end @info "loss: $(l/length(loader_test))" @@ -33,4 +34,4 @@ end data = [(๐ฑ, ๐ฒ) for (๐ฑ, ๐ฒ) in loader_train] |> device opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3)) -Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 10))) +Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 5))) diff --git a/src/NeuralOperators.jl b/src/NeuralOperators.jl index ddd13b59..bffee18f 100644 --- a/src/NeuralOperators.jl +++ b/src/NeuralOperators.jl @@ -6,6 +6,9 @@ module NeuralOperators using Flux using FFTW using Tullio + using CUDA + using CUDAKernels + using KernelAbstractions using Zygote function __init__() diff --git a/src/fourier.jl b/src/fourier.jl index 38958b71..8206b808 100644 --- a/src/fourier.jl +++ b/src/fourier.jl @@ -1,7 +1,6 @@ export SpectralConv1d, - FourierOperator, - FNO + FourierOperator struct SpectralConv1d{T, S} weight::T @@ -28,26 +27,31 @@ function SpectralConv1d( return Chain( x -> Zygote.hook(real, x), - SpectralConv1d(weights, in_chs, out_chs, modes, ฯƒ) + SpectralConv1d(weights, in_chs, out_chs, modes, ฯƒ), ) end Flux.@functor SpectralConv1d +t(๐ฑ) = @tullio ๐ฑแต€[i, j, k] := ๐ฑ[j, i, k] +ein_mul(๐ฑโ‚, ๐ฑโ‚‚) = @tullio ๐ฒ[m, o, b] := ๐ฑโ‚[m, i, b] * ๐ฑโ‚‚[o, i, m] + function (m::SpectralConv1d)(๐ฑ::AbstractArray) - ๐ฑ_fft = fft(๐ฑ, 2) # [in_chs, x, batch] - ๐ฑ_selected = ๐ฑ_fft[:, 1:m.modes, :] # [in_chs, modes, batch] + ๐ฑแต€ = t(๐ฑ) # [x, in_chs, batch] <- [in_chs, x, batch] + ๐ฑ_fft = fft(๐ฑแต€, 1) # [x, in_chs, batch] + ๐ฑ_selected = ๐ฑ_fft[1:m.modes, :, :] # [modes, in_chs, batch] - # [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes] - @tullio ๐ฑ_weighted[o, m, b] := ๐ฑ_selected[i, m, b] * m.weight[o, i, m] + # [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes] + ๐ฑ_weighted = ein_mul(๐ฑ_selected, m.weight) - s = size(๐ฑ_weighted) - d = size(๐ฑ, 2) - m.modes - ๐ฑ_padded = cat(๐ฑ_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2) + s = size(๐ฑ_weighted)[2:end] + d = size(๐ฑแต€, 1) - m.modes + ๐ฑ_padded = cat(๐ฑ_weighted, zeros(ComplexF32, d, s...), dims=1) - ๐ฑ_out = ifft(๐ฑ_padded, 2) + ๐ฑ_out = ifft(๐ฑ_padded, 1) # [x, out_chs, batch] + ๐ฑ_outแต€ = t(๐ฑ_out) # [out_chs, x, batch] <- [x, out_chs, batch] - return m.ฯƒ.(real(๐ฑ_out)) + return m.ฯƒ.(real(๐ฑ_outแต€)) end function FourierOperator(