Skip to content
This repository has been archived by the owner on Sep 28, 2024. It is now read-only.

support GPU #7

Merged
merged 6 commits into from
Aug 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,25 @@ authors = ["JingYu Ning <[email protected]> and contributors"]
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "3.3"
CUDAKernels = "0.3"
DataDeps = "0.7"
FFTW = "1.4"
Fetch = "0.1"
Flux = "0.12"
KernelAbstractions = "0.7"
MAT = "0.10"
Tullio = "0.3"
Zygote = "0.6"
Expand Down
2 changes: 1 addition & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[NeuralOperators]]
path = ".."
uuid = "9ab867d4-5049-4b07-85bc-95379d8d6d9c"
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
version = "0.1.0"

[[Parsers]]
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
NeuralOperators = "9ab867d4-5049-4b07-85bc-95379d8d6d9c"
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
19 changes: 10 additions & 9 deletions example/burgers.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using NeuralOperators
using Flux
# using CUDA
using CUDA

# if has_cuda()
# @info "CUDA is on"
# device = gpu
# CUDA.allowscalar(false)
# else
if has_cuda()
@info "CUDA is on"
device = gpu
CUDA.allowscalar(false)
else
device = cpu
# end
end

m = FourierNeuralOperator() |> device
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
Expand All @@ -24,13 +24,14 @@ n_test = 40
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=20, shuffle=false)

function loss_test()
l = 0
l = 0f0
for (𝐱, 𝐲) in loader_test
𝐱, 𝐲 = device(𝐱), device(𝐲)
l += loss(𝐱, 𝐲)
end
@info "loss: $(l/length(loader_test))"
end

data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 10)))
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 5)))
3 changes: 3 additions & 0 deletions src/NeuralOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ module NeuralOperators
using Flux
using FFTW
using Tullio
using CUDA
using CUDAKernels
using KernelAbstractions
using Zygote

function __init__()
Expand Down
28 changes: 16 additions & 12 deletions src/fourier.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
export
SpectralConv1d,
FourierOperator,
FNO
FourierOperator

struct SpectralConv1d{T, S}
weight::T
Expand All @@ -28,26 +27,31 @@ function SpectralConv1d(

return Chain(
x -> Zygote.hook(real, x),
SpectralConv1d(weights, in_chs, out_chs, modes, σ)
SpectralConv1d(weights, in_chs, out_chs, modes, σ),
)
end

Flux.@functor SpectralConv1d

t(𝐱) = @tullio 𝐱ᵀ[i, j, k] := 𝐱[j, i, k]
ein_mul(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]

function (m::SpectralConv1d)(𝐱::AbstractArray)
𝐱_fft = fft(𝐱, 2) # [in_chs, x, batch]
𝐱_selected = 𝐱_fft[:, 1:m.modes, :] # [in_chs, modes, batch]
𝐱ᵀ = t(𝐱) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
𝐱_selected = 𝐱_fft[1:m.modes, :, :] # [modes, in_chs, batch]

# [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes]
@tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i, m, b] * m.weight[o, i, m]
# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
𝐱_weighted = ein_mul(𝐱_selected, m.weight)

s = size(𝐱_weighted)
d = size(𝐱, 2) - m.modes
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2)
s = size(𝐱_weighted)[2:end]
d = size(𝐱ᵀ, 1) - m.modes
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, d, s...), dims=1)

𝐱_out = ifft(𝐱_padded, 2)
𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
𝐱_outᵀ = t(𝐱_out) # [out_chs, x, batch] <- [x, out_chs, batch]

return m.σ.(real(𝐱_out))
return m.σ.(real(𝐱_outᵀ))
end

function FourierOperator(
Expand Down